1use super::builder_utils::build_optional_adapter;
8use oar_ocr_core::core::OCRError;
9use oar_ocr_core::core::config::OrtSessionConfig;
10use oar_ocr_core::core::traits::OrtConfigurable;
11use oar_ocr_core::core::traits::adapter::{AdapterBuilder, ModelAdapter};
12use oar_ocr_core::domain::adapters::{
13 DocumentOrientationAdapter, DocumentOrientationAdapterBuilder, FormulaRecognitionAdapter,
14 LayoutDetectionAdapter, LayoutDetectionAdapterBuilder, PPFormulaNetAdapterBuilder,
15 SLANetWiredAdapterBuilder, SLANetWirelessAdapterBuilder, SealTextDetectionAdapter,
16 SealTextDetectionAdapterBuilder, TableCellDetectionAdapter, TableCellDetectionAdapterBuilder,
17 TableClassificationAdapter, TableClassificationAdapterBuilder,
18 TableStructureRecognitionAdapter, TextDetectionAdapter, TextDetectionAdapterBuilder,
19 TextLineOrientationAdapter, TextLineOrientationAdapterBuilder, TextRecognitionAdapter,
20 TextRecognitionAdapterBuilder, UVDocRectifierAdapter, UVDocRectifierAdapterBuilder,
21 UniMERNetAdapterBuilder,
22};
23use oar_ocr_core::domain::structure::{StructureResult, TableResult};
24use oar_ocr_core::domain::tasks::{
25 FormulaRecognitionConfig, LayoutDetectionConfig, TableCellDetectionConfig,
26 TableClassificationConfig, TableStructureRecognitionConfig, TextDetectionConfig,
27 TextRecognitionConfig,
28};
29use std::path::PathBuf;
30use std::sync::Arc;
31
// Geometric thresholds used when reconciling detected boxes. Their call sites
// are later in this file; the comments below restate the names' intent —
// NOTE(review): confirm against the actual usage further down.

/// IoU (intersection-over-union) at or above which two layout boxes are
/// treated as overlapping.
const LAYOUT_OVERLAP_IOU_THRESHOLD: f32 = 0.5;

/// IoU at or above which two table-cell boxes are treated as overlapping.
const CELL_OVERLAP_IOU_THRESHOLD: f32 = 0.5;

/// IoA (intersection-over-area) at or above which an element is considered a
/// member of a region block.
const REGION_MEMBERSHIP_IOA_THRESHOLD: f32 = 0.1;

/// IoA at or above which a detected text box is split across layout elements.
const TEXT_BOX_SPLIT_IOA_THRESHOLD: f32 = 0.3;
/// Bundle of every constructed model adapter plus the runtime flags the
/// structure pipeline needs at inference time.
///
/// Built exclusively by [`OARStructureBuilder::build`]; only the layout
/// detection adapter is mandatory — every other stage is optional.
#[derive(Debug)]
struct StructurePipeline {
    /// Optional whole-document orientation classifier.
    document_orientation_adapter: Option<DocumentOrientationAdapter>,
    /// Optional UVDoc document rectification (unwarping) model.
    rectification_adapter: Option<UVDocRectifierAdapter>,

    /// Mandatory layout detection model — the only required stage.
    layout_detection_adapter: LayoutDetectionAdapter,

    /// Optional coarse region detector (built from the PP-DocBlockLayout preset).
    region_detection_adapter: Option<LayoutDetectionAdapter>,

    /// Optional table classifier (wired vs. wireless, per its builder).
    table_classification_adapter: Option<TableClassificationAdapter>,
    /// Optional orientation classifier for tables — presumably applied to
    /// table crops; confirm at the call site.
    table_orientation_adapter: Option<DocumentOrientationAdapter>,
    /// Optional generic table-cell detector (wired/wireless chosen at build time).
    table_cell_detection_adapter: Option<TableCellDetectionAdapter>,
    /// Optional generic table-structure recognizer (SLANet wired or wireless).
    table_structure_recognition_adapter: Option<TableStructureRecognitionAdapter>,
    /// Dedicated wired-table structure recognizer (SLANet wired).
    wired_table_structure_adapter: Option<TableStructureRecognitionAdapter>,
    /// Dedicated wireless-table structure recognizer (SLANet wireless).
    wireless_table_structure_adapter: Option<TableStructureRecognitionAdapter>,
    /// Dedicated wired-table cell detector.
    wired_table_cell_adapter: Option<TableCellDetectionAdapter>,
    /// Dedicated wireless-table cell detector.
    wireless_table_cell_adapter: Option<TableCellDetectionAdapter>,
    /// Use end-to-end recognition for wired tables (builder default: false).
    use_e2e_wired_table_rec: bool,
    /// Use end-to-end recognition for wireless tables (builder default: true).
    use_e2e_wireless_table_rec: bool,
    /// Translate wired-table cell detections directly into HTML.
    use_wired_table_cells_trans_to_html: bool,
    /// Translate wireless-table cell detections directly into HTML.
    use_wireless_table_cells_trans_to_html: bool,

    /// Optional formula recognizer (PP-FormulaNet or UniMERNet).
    formula_recognition_adapter: Option<FormulaRecognitionAdapter>,

    /// Optional seal (stamp) text detector.
    seal_text_detection_adapter: Option<SealTextDetectionAdapter>,

    /// Optional OCR text detector.
    text_detection_adapter: Option<TextDetectionAdapter>,
    /// Optional per-text-line orientation classifier.
    text_line_orientation_adapter: Option<TextLineOrientationAdapter>,
    /// Optional OCR text recognizer.
    text_recognition_adapter: Option<TextRecognitionAdapter>,

    /// Batch size for per-region inference; `None` defers to the consumer's
    /// default (consumer is beyond this excerpt).
    region_batch_size: Option<usize>,
}
88
/// Configuration builder for [`OARStructure`].
///
/// Only the layout detection model path is required (see
/// [`OARStructureBuilder::new`]); every other stage is opt-in via the
/// `with_*` methods and tuned via the `*_config` methods. Call `build()` to
/// validate the configuration and construct the pipeline.
#[derive(Debug, Clone)]
pub struct OARStructureBuilder {
    // --- Layout detection (mandatory) ---
    /// Path to the layout detection model file.
    layout_detection_model: PathBuf,
    /// Preset name selecting the layout model configuration.
    layout_model_name: Option<String>,

    // --- Document preprocessing (optional) ---
    /// Document orientation classifier model path.
    document_orientation_model: Option<PathBuf>,
    /// Document rectification (UVDoc) model path.
    document_rectification_model: Option<PathBuf>,

    /// Coarse region detection model path.
    region_detection_model: Option<PathBuf>,

    // --- Table recognition (optional) ---
    /// Table classification model path.
    table_classification_model: Option<PathBuf>,
    /// Table orientation classifier model path.
    table_orientation_model: Option<PathBuf>,
    /// Generic table-cell detection model path.
    table_cell_detection_model: Option<PathBuf>,
    /// Cell detector flavor: "wired" or "wireless".
    table_cell_detection_type: Option<String>,
    /// Generic table-structure recognition model path.
    table_structure_recognition_model: Option<PathBuf>,
    /// Structure recognizer flavor: "wired" or "wireless".
    table_structure_recognition_type: Option<String>,
    /// Dictionary path required by every table-structure recognizer.
    table_structure_dict_path: Option<PathBuf>,

    /// Dedicated wired-table structure model path.
    wired_table_structure_model: Option<PathBuf>,
    /// Dedicated wireless-table structure model path.
    wireless_table_structure_model: Option<PathBuf>,
    /// Dedicated wired-table cell detection model path.
    wired_table_cell_model: Option<PathBuf>,
    /// Dedicated wireless-table cell detection model path.
    wireless_table_cell_model: Option<PathBuf>,
    /// End-to-end wired table recognition (default: false).
    use_e2e_wired_table_rec: bool,
    /// End-to-end wireless table recognition (default: true).
    use_e2e_wireless_table_rec: bool,
    /// Translate wired-table cells directly to HTML (default: false).
    use_wired_table_cells_trans_to_html: bool,
    /// Translate wireless-table cells directly to HTML (default: false).
    use_wireless_table_cells_trans_to_html: bool,

    // --- Formula recognition (optional) ---
    /// Formula recognition model path.
    formula_recognition_model: Option<PathBuf>,
    /// Formula model flavor: "pp_formulanet" or "unimernet".
    formula_recognition_type: Option<String>,
    /// Tokenizer path required by the formula recognizer.
    formula_tokenizer_path: Option<PathBuf>,

    /// Seal text detection model path.
    seal_text_detection_model: Option<PathBuf>,

    // --- OCR (optional) ---
    /// Text detection model path.
    text_detection_model: Option<PathBuf>,
    /// Text-line orientation classifier model path.
    text_line_orientation_model: Option<PathBuf>,
    /// Text recognition model path.
    text_recognition_model: Option<PathBuf>,
    /// Character dictionary path, required alongside `text_recognition_model`.
    character_dict_path: Option<PathBuf>,

    // --- Model preset names ---
    /// Region detector preset name (only PP-DocBlockLayout exists today).
    region_model_name: Option<String>,
    // NOTE(review): the following *_model_name fields are stored but not read
    // by build() in this view — presumably reserved for future presets or
    // consumed elsewhere; confirm.
    wired_table_structure_model_name: Option<String>,
    wireless_table_structure_model_name: Option<String>,
    wired_table_cell_model_name: Option<String>,
    wireless_table_cell_model_name: Option<String>,
    text_detection_model_name: Option<String>,
    text_recognition_model_name: Option<String>,

    // --- Per-task configuration overrides ---
    /// ONNX Runtime session configuration applied to every adapter.
    ort_session_config: Option<OrtSessionConfig>,
    layout_detection_config: Option<LayoutDetectionConfig>,
    table_classification_config: Option<TableClassificationConfig>,
    table_cell_detection_config: Option<TableCellDetectionConfig>,
    table_structure_recognition_config: Option<TableStructureRecognitionConfig>,
    formula_recognition_config: Option<FormulaRecognitionConfig>,
    text_detection_config: Option<TextDetectionConfig>,
    text_recognition_config: Option<TextRecognitionConfig>,

    // --- Batch sizes ---
    /// Page-image batch size — NOTE(review): not consumed by build() in this
    /// view; confirm intended use.
    image_batch_size: Option<usize>,
    /// Per-region inference batch size, forwarded to the pipeline.
    region_batch_size: Option<usize>,
}
194
195impl OARStructureBuilder {
196 pub fn new(layout_detection_model: impl Into<PathBuf>) -> Self {
202 Self {
203 layout_detection_model: layout_detection_model.into(),
204 layout_model_name: None,
205 document_orientation_model: None,
206 document_rectification_model: None,
207 region_detection_model: None,
208 table_classification_model: None,
209 table_orientation_model: None,
210 table_cell_detection_model: None,
211 table_cell_detection_type: None,
212 table_structure_recognition_model: None,
213 table_structure_recognition_type: None,
214 table_structure_dict_path: None,
215 wired_table_structure_model: None,
216 wireless_table_structure_model: None,
217 wired_table_cell_model: None,
218 wireless_table_cell_model: None,
219 use_e2e_wired_table_rec: false,
221 use_e2e_wireless_table_rec: true,
222 use_wired_table_cells_trans_to_html: false,
223 use_wireless_table_cells_trans_to_html: false,
224 formula_recognition_model: None,
225 formula_recognition_type: None,
226 formula_tokenizer_path: None,
227 seal_text_detection_model: None,
228 text_detection_model: None,
229 text_line_orientation_model: None,
230 text_recognition_model: None,
231 character_dict_path: None,
232 region_model_name: None,
233 wired_table_structure_model_name: None,
234 wireless_table_structure_model_name: None,
235 wired_table_cell_model_name: None,
236 wireless_table_cell_model_name: None,
237 text_detection_model_name: None,
238 text_recognition_model_name: None,
239 ort_session_config: None,
240 layout_detection_config: None,
241 table_classification_config: None,
242 table_cell_detection_config: None,
243 table_structure_recognition_config: None,
244 formula_recognition_config: None,
245 text_detection_config: None,
246 text_recognition_config: None,
247 image_batch_size: None,
248 region_batch_size: None,
249 }
250 }
251
252 pub fn ort_session(mut self, config: OrtSessionConfig) -> Self {
256 self.ort_session_config = Some(config);
257 self
258 }
259
260 pub fn layout_detection_config(mut self, config: LayoutDetectionConfig) -> Self {
262 self.layout_detection_config = Some(config);
263 self
264 }
265
266 pub fn layout_model_name(mut self, name: impl Into<String>) -> Self {
276 self.layout_model_name = Some(name.into());
277 self
278 }
279
280 pub fn region_model_name(mut self, name: impl Into<String>) -> Self {
285 self.region_model_name = Some(name.into());
286 self
287 }
288
289 pub fn wired_table_structure_model_name(mut self, name: impl Into<String>) -> Self {
293 self.wired_table_structure_model_name = Some(name.into());
294 self
295 }
296
297 pub fn wireless_table_structure_model_name(mut self, name: impl Into<String>) -> Self {
301 self.wireless_table_structure_model_name = Some(name.into());
302 self
303 }
304
305 pub fn wired_table_cell_model_name(mut self, name: impl Into<String>) -> Self {
309 self.wired_table_cell_model_name = Some(name.into());
310 self
311 }
312
313 pub fn wireless_table_cell_model_name(mut self, name: impl Into<String>) -> Self {
317 self.wireless_table_cell_model_name = Some(name.into());
318 self
319 }
320
321 pub fn text_detection_model_name(mut self, name: impl Into<String>) -> Self {
325 self.text_detection_model_name = Some(name.into());
326 self
327 }
328
329 pub fn text_recognition_model_name(mut self, name: impl Into<String>) -> Self {
333 self.text_recognition_model_name = Some(name.into());
334 self
335 }
336
337 pub fn image_batch_size(mut self, size: usize) -> Self {
341 self.image_batch_size = Some(size);
342 self
343 }
344
345 pub fn region_batch_size(mut self, size: usize) -> Self {
350 self.region_batch_size = Some(size);
351 self
352 }
353
354 pub fn with_document_orientation(mut self, model_path: impl Into<PathBuf>) -> Self {
359 self.document_orientation_model = Some(model_path.into());
360 self
361 }
362
363 pub fn with_document_rectification(mut self, model_path: impl Into<PathBuf>) -> Self {
368 self.document_rectification_model = Some(model_path.into());
369 self
370 }
371
372 pub fn with_region_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
385 self.region_detection_model = Some(model_path.into());
386 self
387 }
388
389 pub fn with_seal_text_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
394 self.seal_text_detection_model = Some(model_path.into());
395 self
396 }
397
398 pub fn with_table_classification(mut self, model_path: impl Into<PathBuf>) -> Self {
402 self.table_classification_model = Some(model_path.into());
403 self
404 }
405
406 pub fn table_classification_config(mut self, config: TableClassificationConfig) -> Self {
408 self.table_classification_config = Some(config);
409 self
410 }
411
412 pub fn with_table_orientation(mut self, model_path: impl Into<PathBuf>) -> Self {
422 self.table_orientation_model = Some(model_path.into());
423 self
424 }
425
426 pub fn use_e2e_wired_table_rec(mut self, enabled: bool) -> Self {
434 self.use_e2e_wired_table_rec = enabled;
435 self
436 }
437
438 pub fn use_e2e_wireless_table_rec(mut self, enabled: bool) -> Self {
446 self.use_e2e_wireless_table_rec = enabled;
447 self
448 }
449
450 pub fn use_wired_table_cells_trans_to_html(mut self, enabled: bool) -> Self {
455 self.use_wired_table_cells_trans_to_html = enabled;
456 self
457 }
458
459 pub fn use_wireless_table_cells_trans_to_html(mut self, enabled: bool) -> Self {
464 self.use_wireless_table_cells_trans_to_html = enabled;
465 self
466 }
467
468 pub fn with_table_cell_detection(
475 mut self,
476 model_path: impl Into<PathBuf>,
477 cell_type: impl Into<String>,
478 ) -> Self {
479 self.table_cell_detection_model = Some(model_path.into());
480 self.table_cell_detection_type = Some(cell_type.into());
481 self
482 }
483
484 pub fn table_cell_detection_config(mut self, config: TableCellDetectionConfig) -> Self {
486 self.table_cell_detection_config = Some(config);
487 self
488 }
489
490 pub fn with_table_structure_recognition(
499 mut self,
500 model_path: impl Into<PathBuf>,
501 table_type: impl Into<String>,
502 ) -> Self {
503 self.table_structure_recognition_model = Some(model_path.into());
504 self.table_structure_recognition_type = Some(table_type.into());
505 self
506 }
507
508 pub fn table_structure_dict_path(mut self, path: impl Into<PathBuf>) -> Self {
515 self.table_structure_dict_path = Some(path.into());
516 self
517 }
518
519 pub fn table_structure_recognition_config(
521 mut self,
522 config: TableStructureRecognitionConfig,
523 ) -> Self {
524 self.table_structure_recognition_config = Some(config);
525 self
526 }
527
528 pub fn with_wired_table_structure(mut self, model_path: impl Into<PathBuf>) -> Self {
533 self.wired_table_structure_model = Some(model_path.into());
534 self
535 }
536
537 pub fn with_wireless_table_structure(mut self, model_path: impl Into<PathBuf>) -> Self {
542 self.wireless_table_structure_model = Some(model_path.into());
543 self
544 }
545
546 pub fn with_wired_table_cell_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
551 self.wired_table_cell_model = Some(model_path.into());
552 self
553 }
554
555 pub fn with_wireless_table_cell_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
560 self.wireless_table_cell_model = Some(model_path.into());
561 self
562 }
563
564 pub fn with_formula_recognition(
574 mut self,
575 model_path: impl Into<PathBuf>,
576 tokenizer_path: impl Into<PathBuf>,
577 model_type: impl Into<String>,
578 ) -> Self {
579 self.formula_recognition_model = Some(model_path.into());
580 self.formula_tokenizer_path = Some(tokenizer_path.into());
581 self.formula_recognition_type = Some(model_type.into());
582 self
583 }
584
585 pub fn formula_recognition_config(mut self, config: FormulaRecognitionConfig) -> Self {
587 self.formula_recognition_config = Some(config);
588 self
589 }
590
591 pub fn with_ocr(
599 mut self,
600 text_detection_model: impl Into<PathBuf>,
601 text_recognition_model: impl Into<PathBuf>,
602 character_dict_path: impl Into<PathBuf>,
603 ) -> Self {
604 self.text_detection_model = Some(text_detection_model.into());
605 self.text_recognition_model = Some(text_recognition_model.into());
606 self.character_dict_path = Some(character_dict_path.into());
607 self
608 }
609
610 pub fn with_text_line_orientation(mut self, model_path: impl Into<PathBuf>) -> Self {
621 self.text_line_orientation_model = Some(model_path.into());
622 self
623 }
624
625 pub fn text_detection_config(mut self, config: TextDetectionConfig) -> Self {
627 self.text_detection_config = Some(config);
628 self
629 }
630
631 pub fn text_recognition_config(mut self, config: TextRecognitionConfig) -> Self {
633 self.text_recognition_config = Some(config);
634 self
635 }
636
637 pub fn build(self) -> Result<OARStructure, OCRError> {
641 let char_dict = if let Some(ref dict_path) = self.character_dict_path {
643 Some(
644 std::fs::read_to_string(dict_path).map_err(|e| OCRError::InvalidInput {
645 message: format!(
646 "Failed to read character dictionary from '{}': {}",
647 dict_path.display(),
648 e
649 ),
650 })?,
651 )
652 } else {
653 None
654 };
655
656 let document_orientation_adapter = build_optional_adapter(
658 self.document_orientation_model.as_ref(),
659 self.ort_session_config.as_ref(),
660 DocumentOrientationAdapterBuilder::new,
661 )?;
662
663 let rectification_adapter = build_optional_adapter(
665 self.document_rectification_model.as_ref(),
666 self.ort_session_config.as_ref(),
667 UVDocRectifierAdapterBuilder::new,
668 )?;
669
670 let mut layout_builder = LayoutDetectionAdapterBuilder::new();
672
673 let layout_model_config = if let Some(name) = &self.layout_model_name {
675 use oar_ocr_core::domain::adapters::LayoutModelConfig;
676 match name.as_str() {
677 "picodet_layout_1x" => LayoutModelConfig::picodet_layout_1x(),
678 "picodet_layout_1x_table" => LayoutModelConfig::picodet_layout_1x_table(),
679 "picodet_s_layout_3cls" => LayoutModelConfig::picodet_s_layout_3cls(),
680 "picodet_l_layout_3cls" => LayoutModelConfig::picodet_l_layout_3cls(),
681 "picodet_s_layout_17cls" => LayoutModelConfig::picodet_s_layout_17cls(),
682 "picodet_l_layout_17cls" => LayoutModelConfig::picodet_l_layout_17cls(),
683 "rt-detr-h_layout_3cls" => LayoutModelConfig::rtdetr_h_layout_3cls(),
684 "rt-detr-h_layout_17cls" => LayoutModelConfig::rtdetr_h_layout_17cls(),
685 "pp-docblocklayout" => LayoutModelConfig::pp_docblocklayout(),
686 "pp-doclayout-s" => LayoutModelConfig::pp_doclayout_s(),
687 "pp-doclayout-m" => LayoutModelConfig::pp_doclayout_m(),
688 "pp-doclayout-l" => LayoutModelConfig::pp_doclayout_l(),
689 "pp-doclayout_plus-l" => LayoutModelConfig::pp_doclayout_plus_l(),
690 _ => LayoutModelConfig::pp_doclayout_plus_l(),
691 }
692 } else {
693 crate::domain::adapters::LayoutModelConfig::pp_doclayout_plus_l()
695 };
696
697 layout_builder = layout_builder.model_config(layout_model_config);
698
699 let effective_layout_cfg = self
701 .layout_detection_config
702 .clone()
703 .unwrap_or_else(LayoutDetectionConfig::with_pp_structurev3_defaults);
704 layout_builder = layout_builder.with_config(effective_layout_cfg);
705
706 if let Some(ref ort_config) = self.ort_session_config {
707 layout_builder = layout_builder.with_ort_config(ort_config.clone());
708 }
709
710 let layout_detection_adapter = layout_builder.build(&self.layout_detection_model)?;
711
712 let region_detection_adapter = if let Some(ref model_path) = self.region_detection_model {
714 use oar_ocr_core::domain::adapters::LayoutModelConfig;
715 let mut region_builder = LayoutDetectionAdapterBuilder::new();
716
717 let region_model_config = if let Some(ref name) = self.region_model_name {
719 match name.to_lowercase().replace("-", "_").as_str() {
720 "pp_docblocklayout" => LayoutModelConfig::pp_docblocklayout(),
721 _ => LayoutModelConfig::pp_docblocklayout(),
722 }
723 } else {
724 LayoutModelConfig::pp_docblocklayout()
725 };
726 region_builder = region_builder.model_config(region_model_config);
727
728 let mut region_cfg = LayoutDetectionConfig::default();
730 let mut merge_modes = std::collections::HashMap::new();
731 merge_modes.insert(
732 "region".to_string(),
733 crate::domain::tasks::layout_detection::MergeBboxMode::Small,
734 );
735 region_cfg.class_merge_modes = Some(merge_modes);
736 region_builder = region_builder.with_config(region_cfg);
737
738 if let Some(ref ort_config) = self.ort_session_config {
739 region_builder = region_builder.with_ort_config(ort_config.clone());
740 }
741
742 Some(region_builder.build(model_path)?)
743 } else {
744 None
745 };
746
747 let table_classification_adapter =
749 if let Some(ref model_path) = self.table_classification_model {
750 let mut builder = TableClassificationAdapterBuilder::new();
751
752 if let Some(ref config) = self.table_classification_config {
753 builder = builder.with_config(config.clone());
754 }
755
756 if let Some(ref ort_config) = self.ort_session_config {
757 builder = builder.with_ort_config(ort_config.clone());
758 }
759
760 Some(builder.build(model_path)?)
761 } else {
762 None
763 };
764
765 let table_orientation_adapter = build_optional_adapter(
768 self.table_orientation_model.as_ref(),
769 self.ort_session_config.as_ref(),
770 DocumentOrientationAdapterBuilder::new,
771 )?;
772
773 let table_cell_detection_adapter = if let Some(ref model_path) =
775 self.table_cell_detection_model
776 {
777 let cell_type = self.table_cell_detection_type.as_deref().unwrap_or("wired");
778
779 use oar_ocr_core::domain::adapters::table_cell_detection_adapter::TableCellModelConfig;
780
781 let model_config = match cell_type {
782 "wired" => TableCellModelConfig::rtdetr_l_wired_table_cell_det(),
783 "wireless" => TableCellModelConfig::rtdetr_l_wireless_table_cell_det(),
784 _ => {
785 return Err(OCRError::config_error_detailed(
786 "table_cell_detection",
787 format!(
788 "Invalid cell type '{}': must be 'wired' or 'wireless'",
789 cell_type
790 ),
791 ));
792 }
793 };
794
795 let mut builder = TableCellDetectionAdapterBuilder::new().model_config(model_config);
796
797 if let Some(ref config) = self.table_cell_detection_config {
798 builder = builder.with_config(config.clone());
799 }
800
801 if let Some(ref ort_config) = self.ort_session_config {
802 builder = builder.with_ort_config(ort_config.clone());
803 }
804
805 Some(builder.build(model_path)?)
806 } else {
807 None
808 };
809
810 let table_structure_recognition_adapter = if let Some(ref model_path) =
812 self.table_structure_recognition_model
813 {
814 let table_type = self
815 .table_structure_recognition_type
816 .as_deref()
817 .unwrap_or("wired");
818 let dict_path = self
819 .table_structure_dict_path
820 .clone()
821 .ok_or_else(|| {
822 OCRError::config_error_detailed(
823 "table_structure_recognition",
824 "Dictionary path is required. Call table_structure_dict_path() when enabling table structure recognition.".to_string(),
825 )
826 })?;
827
828 let adapter: TableStructureRecognitionAdapter = match table_type {
829 "wired" => {
830 let mut builder = SLANetWiredAdapterBuilder::new().dict_path(dict_path.clone());
831
832 if let Some(ref config) = self.table_structure_recognition_config {
833 builder = builder.with_config(config.clone());
834 }
835
836 if let Some(ref ort_config) = self.ort_session_config {
837 builder = builder.with_ort_config(ort_config.clone());
838 }
839
840 builder.build(model_path)?
841 }
842 "wireless" => {
843 let mut builder =
844 SLANetWirelessAdapterBuilder::new().dict_path(dict_path.clone());
845
846 if let Some(ref config) = self.table_structure_recognition_config {
847 builder = builder.with_config(config.clone());
848 }
849
850 if let Some(ref ort_config) = self.ort_session_config {
851 builder = builder.with_ort_config(ort_config.clone());
852 }
853
854 builder.build(model_path)?
855 }
856 _ => {
857 return Err(OCRError::config_error_detailed(
858 "table_structure_recognition",
859 format!(
860 "Invalid table type '{}': must be 'wired' or 'wireless'",
861 table_type
862 ),
863 ));
864 }
865 };
866
867 Some(adapter)
868 } else {
869 None
870 };
871
872 let wired_table_structure_adapter = if let Some(ref model_path) =
874 self.wired_table_structure_model
875 {
876 let dict_path = self.table_structure_dict_path.clone().ok_or_else(|| {
877 OCRError::config_error_detailed(
878 "wired_table_structure",
879 "Dictionary path is required. Call table_structure_dict_path() when enabling table structure recognition.".to_string(),
880 )
881 })?;
882
883 let mut builder = SLANetWiredAdapterBuilder::new().dict_path(dict_path);
884
885 if let Some(ref config) = self.table_structure_recognition_config {
886 builder = builder.with_config(config.clone());
887 }
888
889 if let Some(ref ort_config) = self.ort_session_config {
890 builder = builder.with_ort_config(ort_config.clone());
891 }
892
893 Some(builder.build(model_path)?)
894 } else {
895 None
896 };
897
898 let wireless_table_structure_adapter = if let Some(ref model_path) =
899 self.wireless_table_structure_model
900 {
901 let dict_path = self.table_structure_dict_path.clone().ok_or_else(|| {
902 OCRError::config_error_detailed(
903 "wireless_table_structure",
904 "Dictionary path is required. Call table_structure_dict_path() when enabling table structure recognition.".to_string(),
905 )
906 })?;
907
908 let mut builder = SLANetWirelessAdapterBuilder::new().dict_path(dict_path);
909
910 if let Some(ref config) = self.table_structure_recognition_config {
911 builder = builder.with_config(config.clone());
912 }
913
914 if let Some(ref ort_config) = self.ort_session_config {
915 builder = builder.with_ort_config(ort_config.clone());
916 }
917
918 Some(builder.build(model_path)?)
919 } else {
920 None
921 };
922
923 let wired_table_cell_adapter = if let Some(ref model_path) = self.wired_table_cell_model {
925 use oar_ocr_core::domain::adapters::table_cell_detection_adapter::TableCellModelConfig;
926
927 let model_config = TableCellModelConfig::rtdetr_l_wired_table_cell_det();
928 let mut builder = TableCellDetectionAdapterBuilder::new().model_config(model_config);
929
930 if let Some(ref config) = self.table_cell_detection_config {
931 builder = builder.with_config(config.clone());
932 }
933
934 if let Some(ref ort_config) = self.ort_session_config {
935 builder = builder.with_ort_config(ort_config.clone());
936 }
937
938 Some(builder.build(model_path)?)
939 } else {
940 None
941 };
942
943 let wireless_table_cell_adapter = if let Some(ref model_path) =
944 self.wireless_table_cell_model
945 {
946 use oar_ocr_core::domain::adapters::table_cell_detection_adapter::TableCellModelConfig;
947
948 let model_config = TableCellModelConfig::rtdetr_l_wireless_table_cell_det();
949 let mut builder = TableCellDetectionAdapterBuilder::new().model_config(model_config);
950
951 if let Some(ref config) = self.table_cell_detection_config {
952 builder = builder.with_config(config.clone());
953 }
954
955 if let Some(ref ort_config) = self.ort_session_config {
956 builder = builder.with_ort_config(ort_config.clone());
957 }
958
959 Some(builder.build(model_path)?)
960 } else {
961 None
962 };
963
964 let formula_recognition_adapter = if let Some(ref model_path) =
966 self.formula_recognition_model
967 {
968 let tokenizer_path = self.formula_tokenizer_path.as_ref().ok_or_else(|| {
969 OCRError::config_error_detailed(
970 "formula_recognition",
971 "Tokenizer path is required for formula recognition".to_string(),
972 )
973 })?;
974
975 let model_type = self.formula_recognition_type.as_deref().ok_or_else(|| {
976 OCRError::config_error_detailed(
977 "formula_recognition",
978 "Model type is required (must be 'pp_formulanet' or 'unimernet')".to_string(),
979 )
980 })?;
981
982 let adapter: FormulaRecognitionAdapter = match model_type.to_lowercase().as_str() {
983 "pp_formulanet" | "pp-formulanet" => {
984 let mut builder = PPFormulaNetAdapterBuilder::new();
985
986 builder = builder.tokenizer_path(tokenizer_path);
987
988 if let Some(ref config) = self.formula_recognition_config {
991 builder = builder.task_config(config.clone());
992 }
993
994 if let Some(ref ort_config) = self.ort_session_config {
995 builder = builder.with_ort_config(ort_config.clone());
996 }
997
998 builder.build(model_path)?
999 }
1000 "unimernet" => {
1001 let mut builder = UniMERNetAdapterBuilder::new();
1002
1003 builder = builder.tokenizer_path(tokenizer_path);
1004
1005 if let Some(ref config) = self.formula_recognition_config {
1008 builder = builder.task_config(config.clone());
1009 }
1010
1011 if let Some(ref ort_config) = self.ort_session_config {
1012 builder = builder.with_ort_config(ort_config.clone());
1013 }
1014
1015 builder.build(model_path)?
1016 }
1017 _ => {
1018 return Err(OCRError::config_error_detailed(
1019 "formula_recognition",
1020 format!(
1021 "Invalid model type '{}': must be 'pp_formulanet' or 'unimernet'",
1022 model_type
1023 ),
1024 ));
1025 }
1026 };
1027
1028 Some(adapter)
1029 } else {
1030 None
1031 };
1032
1033 let seal_text_detection_adapter =
1035 if let Some(ref model_path) = self.seal_text_detection_model {
1036 let mut builder = SealTextDetectionAdapterBuilder::new();
1037
1038 if let Some(ref ort_config) = self.ort_session_config {
1039 builder = builder.with_ort_config(ort_config.clone());
1040 }
1041
1042 Some(builder.build(model_path)?)
1043 } else {
1044 None
1045 };
1046
1047 let text_detection_adapter = if let Some(ref model_path) = self.text_detection_model {
1056 let mut builder = TextDetectionAdapterBuilder::new();
1057
1058 let mut effective_cfg = self.text_detection_config.clone().unwrap_or_default();
1061
1062 let has_table_pipeline = self.table_classification_model.is_some()
1065 || self.table_structure_recognition_model.is_some()
1066 || self.wired_table_structure_model.is_some()
1067 || self.wireless_table_structure_model.is_some()
1068 || self.table_cell_detection_model.is_some()
1069 || self.wired_table_cell_model.is_some()
1070 || self.wireless_table_cell_model.is_some();
1071 if self.text_detection_config.is_none() && has_table_pipeline {
1072 effective_cfg.box_threshold = 0.4;
1073 }
1074
1075 if effective_cfg.limit_side_len.is_none() {
1076 effective_cfg.limit_side_len = Some(736);
1077 }
1078 if effective_cfg.limit_type.is_none() {
1079 effective_cfg.limit_type = Some(crate::processors::LimitType::Min);
1080 }
1081 if effective_cfg.max_side_len.is_none() {
1082 effective_cfg.max_side_len = Some(4000);
1083 }
1084 builder = builder.with_config(effective_cfg);
1085
1086 if let Some(ref ort_config) = self.ort_session_config {
1087 builder = builder.with_ort_config(ort_config.clone());
1088 }
1089
1090 Some(builder.build(model_path)?)
1091 } else {
1092 None
1093 };
1094
1095 let text_line_orientation_adapter =
1097 if let Some(ref model_path) = self.text_line_orientation_model {
1098 let mut builder = TextLineOrientationAdapterBuilder::new();
1099
1100 if let Some(ref ort_config) = self.ort_session_config {
1101 builder = builder.with_ort_config(ort_config.clone());
1102 }
1103
1104 Some(builder.build(model_path)?)
1105 } else {
1106 None
1107 };
1108
1109 let text_recognition_adapter = if let Some(ref model_path) = self.text_recognition_model {
1111 let dict = char_dict.ok_or_else(|| OCRError::InvalidInput {
1112 message: "Character dictionary is required for text recognition".to_string(),
1113 })?;
1114
1115 let char_vec: Vec<String> = dict.lines().map(|s| s.to_string()).collect();
1117
1118 let mut builder = TextRecognitionAdapterBuilder::new().character_dict(char_vec);
1119
1120 if let Some(ref config) = self.text_recognition_config {
1123 builder = builder.with_config(config.clone());
1124 }
1125
1126 if let Some(ref ort_config) = self.ort_session_config {
1127 builder = builder.with_ort_config(ort_config.clone());
1128 }
1129
1130 Some(builder.build(model_path)?)
1131 } else {
1132 None
1133 };
1134
1135 let pipeline = StructurePipeline {
1136 document_orientation_adapter,
1137 rectification_adapter,
1138 layout_detection_adapter,
1139 region_detection_adapter,
1140 table_classification_adapter,
1141 table_orientation_adapter,
1142 table_cell_detection_adapter,
1143 table_structure_recognition_adapter,
1144 wired_table_structure_adapter,
1145 wireless_table_structure_adapter,
1146 wired_table_cell_adapter,
1147 wireless_table_cell_adapter,
1148 use_e2e_wired_table_rec: self.use_e2e_wired_table_rec,
1149 use_e2e_wireless_table_rec: self.use_e2e_wireless_table_rec,
1150 use_wired_table_cells_trans_to_html: self.use_wired_table_cells_trans_to_html,
1151 use_wireless_table_cells_trans_to_html: self.use_wireless_table_cells_trans_to_html,
1152 formula_recognition_adapter,
1153 seal_text_detection_adapter,
1154 text_detection_adapter,
1155 text_line_orientation_adapter,
1156 text_recognition_adapter,
1157 region_batch_size: self.region_batch_size,
1158 };
1159
1160 Ok(OARStructure { pipeline })
1161 }
1162}
1163
/// Document-structure analysis pipeline (layout, tables, formulas, seals,
/// OCR), assembled by [`OARStructureBuilder::build`].
#[derive(Debug)]
pub struct OARStructure {
    /// The fully-constructed adapter bundle used by every analysis entry point.
    pipeline: StructurePipeline,
}
1171
/// Per-page intermediate state carried between the preprocessing/layout phase
/// and the downstream table/formula/OCR stages.
struct PreparedPage {
    /// The page image downstream stages operate on (after any preprocessing).
    current_image: std::sync::Arc<image::RgbImage>,
    /// Detected page orientation angle, if orientation classification ran
    /// (presumably degrees — confirm at the producer).
    orientation_angle: Option<f32>,
    /// Rectified (unwarped) page image, if rectification ran.
    rectified_img: Option<std::sync::Arc<image::RgbImage>>,
    /// Orientation correction that was applied to the page, if any.
    rotation: Option<crate::oarocr::preprocess::OrientationCorrection>,
    /// Layout elements detected on this page.
    layout_elements: Vec<crate::domain::structure::LayoutElement>,
    /// Coarse region blocks; present only when region detection is enabled.
    detected_region_blocks: Option<Vec<crate::domain::structure::RegionBlock>>,
}
1182
1183impl OARStructure {
    /// Refines whole-page OCR results using layout-detection output.
    ///
    /// Two passes are performed:
    /// 1. An OCR box that overlaps more than one non-excluded layout element
    ///    is split at the layout boundaries; each sub-crop is re-recognized,
    ///    the first sub-result replaces the original region in place, and the
    ///    remaining sub-results are appended as new regions.
    /// 2. Any text-bearing layout block that ends up with no overlapping
    ///    recognized text gets a fallback crop-and-recognize pass so its
    ///    content is not lost.
    ///
    /// # Arguments
    /// * `text_regions` - page-level OCR regions; modified in place.
    /// * `layout_elements` - layout detection results for the same page.
    /// * `region_blocks` - currently unused; reserved for region-aware refinement.
    /// * `page_image` - full page image that crops are taken from.
    /// * `text_recognition_adapter` - recognizer used for the re-runs.
    /// * `region_batch_size` - recognition batch size; clamped to at least 1.
    ///
    /// # Errors
    /// Propagates any `OCRError` raised by the recognition adapter.
    fn refine_overall_ocr_with_layout(
        text_regions: &mut Vec<crate::oarocr::TextRegion>,
        layout_elements: &[crate::domain::structure::LayoutElement],
        region_blocks: Option<&[crate::domain::structure::RegionBlock]>,
        page_image: &image::RgbImage,
        text_recognition_adapter: &TextRecognitionAdapter,
        region_batch_size: usize,
    ) -> Result<(), OCRError> {
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use oar_ocr_core::domain::structure::LayoutElementType;
        use oar_ocr_core::processors::BoundingBox;
        use oar_ocr_core::utils::BBoxCrop;

        if text_regions.is_empty() || layout_elements.is_empty() {
            return Ok(());
        }

        // Axis-aligned intersection of two boxes; returns None when the
        // overlap is degenerate (1px or less in either dimension).
        fn aabb_intersection(b1: &BoundingBox, b2: &BoundingBox) -> Option<BoundingBox> {
            let x1 = b1.x_min().max(b2.x_min());
            let y1 = b1.y_min().max(b2.y_min());
            let x2 = b1.x_max().min(b2.x_max());
            let y2 = b1.y_max().min(b2.y_max());
            if x2 - x1 <= 1.0 || y2 - y1 <= 1.0 {
                None
            } else {
                Some(BoundingBox::from_coords(x1, y1, x2, y2))
            }
        }

        // These element types carry their own dedicated recognition paths
        // (formula / table / seal pipelines), so plain-text refinement must
        // not touch them.
        let is_excluded_layout = |t: LayoutElementType| {
            matches!(
                t,
                LayoutElementType::Formula
                    | LayoutElementType::FormulaNumber
                    | LayoutElementType::Table
                    | LayoutElementType::Seal
            )
        };

        // Minimum overlap (in pixels, per axis) for an OCR box to count as
        // belonging to a layout element.
        let min_pixels = 3.0;
        // matched_ocr[i] holds the indices of layout elements overlapping
        // text_regions[i].
        let mut matched_ocr: Vec<Vec<usize>> = vec![Vec::new(); text_regions.len()];
        for (ocr_idx, region) in text_regions.iter().enumerate() {
            for (layout_idx, elem) in layout_elements.iter().enumerate() {
                if is_excluded_layout(elem.element_type) {
                    continue;
                }
                let inter_x_min = region.bounding_box.x_min().max(elem.bbox.x_min());
                let inter_y_min = region.bounding_box.y_min().max(elem.bbox.y_min());
                let inter_x_max = region.bounding_box.x_max().min(elem.bbox.x_max());
                let inter_y_max = region.bounding_box.y_max().min(elem.bbox.y_max());
                if inter_x_max - inter_x_min > min_pixels && inter_y_max - inter_y_min > min_pixels
                {
                    matched_ocr[ocr_idx].push(layout_idx);
                }
            }
        }

        // Pass 1: split OCR boxes that span multiple layout elements.
        let mut appended_regions: Vec<crate::oarocr::TextRegion> = Vec::new();
        let original_ocr_len = text_regions.len();
        let mut multi_layout_ocr_count = 0usize;
        let mut multi_layout_crop_count = 0usize;

        for ocr_idx in 0..original_ocr_len {
            let layout_ids = matched_ocr[ocr_idx].clone();
            if layout_ids.len() <= 1 {
                // Boxes contained in at most one layout element stay as-is.
                continue;
            }
            multi_layout_ocr_count += 1;

            let ocr_box = text_regions[ocr_idx].bounding_box.clone();

            let mut crops: Vec<image::RgbImage> = Vec::new();
            // (sub-box, is-first) pairs; the first sub-box overwrites the
            // original region, the rest become new regions.
            let mut crop_boxes: Vec<(BoundingBox, bool)> = Vec::new();
            for (j, layout_idx) in layout_ids.iter().enumerate() {
                let layout_box = &layout_elements[*layout_idx].bbox;
                let Some(crop_box) = aabb_intersection(&ocr_box, layout_box) else {
                    continue;
                };

                // Suppress other OCR regions that essentially duplicate this
                // sub-box (IoU > 0.8): they would be re-recognized here.
                for (other_idx, other_region) in text_regions.iter_mut().enumerate() {
                    if other_idx == ocr_idx {
                        continue;
                    }
                    if other_region.bounding_box.iou(&crop_box) > 0.8 {
                        other_region.text = None;
                    }
                }

                if let Ok(crop_img) = BBoxCrop::crop_bounding_box(page_image, &crop_box) {
                    crops.push(crop_img);
                    // NOTE(review): `j == 0` flags the first layout match; if
                    // that crop fails, no sub-result replaces the original.
                    crop_boxes.push((crop_box, j == 0));
                }
            }
            multi_layout_crop_count += crop_boxes.len();

            if crops.is_empty() {
                continue;
            }

            // Recognize the sub-crops in batches of `region_batch_size`.
            let mut rec_texts: Vec<String> = Vec::with_capacity(crops.len());
            let mut rec_scores: Vec<f32> = Vec::with_capacity(crops.len());

            for batch_start in (0..crops.len()).step_by(region_batch_size.max(1)) {
                let batch_end = (batch_start + region_batch_size).min(crops.len());
                let batch: Vec<_> = crops[batch_start..batch_end].to_vec();
                let rec_input = ImageTaskInput::new(batch);
                let rec_result = text_recognition_adapter.execute(rec_input, None)?;
                rec_texts.extend(rec_result.texts);
                rec_scores.extend(rec_result.scores);
            }

            // Write results back: first sub-box updates the original region,
            // the rest are queued for appending after the loop.
            for ((crop_box, is_first), (text, score)) in crop_boxes
                .into_iter()
                .zip(rec_texts.into_iter().zip(rec_scores.into_iter()))
            {
                if text.is_empty() {
                    continue;
                }
                if is_first {
                    text_regions[ocr_idx].bounding_box = crop_box.clone();
                    text_regions[ocr_idx].dt_poly = Some(crop_box.clone());
                    text_regions[ocr_idx].rec_poly = Some(crop_box.clone());
                    text_regions[ocr_idx].text = Some(Arc::from(text));
                    text_regions[ocr_idx].confidence = Some(score);
                } else {
                    appended_regions.push(crate::oarocr::TextRegion {
                        bounding_box: crop_box.clone(),
                        dt_poly: Some(crop_box.clone()),
                        rec_poly: Some(crop_box),
                        text: Some(Arc::from(text)),
                        confidence: Some(score),
                        orientation_angle: None,
                        word_boxes: None,
                        label: None,
                    });
                }
            }
        }

        if !appended_regions.is_empty() {
            text_regions.extend(appended_regions);
        }

        // Pass 2: fallback recognition for text-type layout blocks that have
        // no overlapping recognized text at all.
        let mut fallback_blocks = 0usize;
        for elem in layout_elements.iter() {
            if is_excluded_layout(elem.element_type) {
                continue;
            }
            // Images and charts carry no recognizable text either.
            if matches!(
                elem.element_type,
                LayoutElementType::Image | LayoutElementType::Chart
            ) {
                continue;
            }

            let mut has_text = false;
            for region in text_regions.iter() {
                // Only regions with non-empty recognized text count.
                if !region.text.as_ref().map(|t| !t.is_empty()).unwrap_or(false) {
                    continue;
                }
                let inter_x_min = region.bounding_box.x_min().max(elem.bbox.x_min());
                let inter_y_min = region.bounding_box.y_min().max(elem.bbox.y_min());
                let inter_x_max = region.bounding_box.x_max().min(elem.bbox.x_max());
                let inter_y_max = region.bounding_box.y_max().min(elem.bbox.y_max());
                if inter_x_max - inter_x_min > min_pixels && inter_y_max - inter_y_min > min_pixels
                {
                    has_text = true;
                    break;
                }
            }

            if has_text {
                continue;
            }
            fallback_blocks += 1;

            // Crop the whole layout block and run a single recognition pass.
            if let Ok(crop_img) = BBoxCrop::crop_bounding_box(page_image, &elem.bbox) {
                let rec_input = ImageTaskInput::new(vec![crop_img]);
                let rec_result = text_recognition_adapter.execute(rec_input, None)?;
                if let (Some(text), Some(score)) =
                    (rec_result.texts.first(), rec_result.scores.first())
                    && !text.is_empty()
                {
                    let crop_box = elem.bbox.clone();
                    text_regions.push(crate::oarocr::TextRegion {
                        bounding_box: crop_box.clone(),
                        dt_poly: Some(crop_box.clone()),
                        rec_poly: Some(crop_box),
                        text: Some(Arc::from(text.as_str())),
                        confidence: Some(*score),
                        orientation_angle: None,
                        word_boxes: None,
                        label: None,
                    });
                }
            }
        }

        tracing::info!(
            "overall OCR refine: multi-layout OCR boxes={}, crops={}, fallback layout blocks={}",
            multi_layout_ocr_count,
            multi_layout_crop_count,
            fallback_blocks
        );

        // Region blocks are accepted but not yet used by the refinement.
        let _ = region_blocks;

        Ok(())
    }
1416
    /// Splits OCR boxes that span multiple table cells into per-cell pieces
    /// and re-recognizes each piece.
    ///
    /// An OCR box overlapping at least two cells (by intersection-over-cell
    /// area above `CELL_OVERLAP_IOU_THRESHOLD`) is cut horizontally at cell
    /// boundaries; every resulting sub-box is cropped from `page_image` and
    /// run through the recognizer. Boxes matching fewer than two cells are
    /// kept unchanged. `text_regions` is replaced wholesale at the end.
    ///
    /// # Errors
    /// Propagates any `OCRError` raised by the recognition adapter.
    fn split_ocr_bboxes_by_table_cells(
        tables: &[TableResult],
        text_regions: &mut Vec<crate::oarocr::TextRegion>,
        page_image: &image::RgbImage,
        text_recognition_adapter: &TextRecognitionAdapter,
    ) -> Result<(), OCRError> {
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use oar_ocr_core::processors::BoundingBox;

        // Collect all non-degenerate cell boxes as [x1, y1, x2, y2] arrays.
        let mut cell_boxes: Vec<[f32; 4]> = Vec::new();
        for table in tables {
            for cell in &table.cells {
                let x1 = cell.bbox.x_min();
                let y1 = cell.bbox.y_min();
                let x2 = cell.bbox.x_max();
                let y2 = cell.bbox.y_max();
                if x2 > x1 && y2 > y1 {
                    cell_boxes.push([x1, y1, x2, y2]);
                }
            }
        }

        if cell_boxes.is_empty() || text_regions.is_empty() {
            return Ok(());
        }

        // Intersection area of box1 and box2, normalized by the *cell* area
        // (box2) — i.e. "how much of the cell is covered".
        fn overlap_ratio_box_over_cell(box1: &[f32; 4], box2: &[f32; 4]) -> f32 {
            let x_left = box1[0].max(box2[0]);
            let y_top = box1[1].max(box2[1]);
            let x_right = box1[2].min(box2[2]);
            let y_bottom = box1[3].min(box2[3]);

            if x_right <= x_left || y_bottom <= y_top {
                return 0.0;
            }

            let inter_area = (x_right - x_left) * (y_bottom - y_top);
            let cell_area = (box2[2] - box2[0]) * (box2[3] - box2[1]);
            if cell_area <= 0.0 {
                0.0
            } else {
                inter_area / cell_area
            }
        }

        // Indices of cells sufficiently covered by `ocr_box`, sorted
        // left-to-right by cell x_min so the split walks across the row.
        fn get_overlapping_cells(
            ocr_box: &[f32; 4],
            cells: &[[f32; 4]],
            threshold: f32,
        ) -> Vec<usize> {
            let mut overlapping = Vec::new();
            for (idx, cell) in cells.iter().enumerate() {
                if overlap_ratio_box_over_cell(ocr_box, cell) > threshold {
                    overlapping.push(idx);
                }
            }
            overlapping.sort_by(|&i, &j| {
                cells[i][0]
                    .partial_cmp(&cells[j][0])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            overlapping
        }

        // Cuts `ocr_box` horizontally at the boundaries of the given cells,
        // keeping the full OCR-box height for every piece. Emits: the part
        // left of the first cell, one clipped piece per cell, any gap between
        // consecutive cells, and the part right of the last cell; exact
        // duplicates are removed (bitwise float comparison).
        fn split_box_by_cells(
            ocr_box: &[f32; 4],
            cell_indices: &[usize],
            cells: &[[f32; 4]],
        ) -> Vec<[f32; 4]> {
            if cell_indices.is_empty() {
                return vec![*ocr_box];
            }

            let mut split_boxes: Vec<[f32; 4]> = Vec::new();
            let cells_to_split: Vec<[f32; 4]> = cell_indices.iter().map(|&i| cells[i]).collect();

            // Leading overhang before the first cell.
            if ocr_box[0] < cells_to_split[0][0] {
                split_boxes.push([ocr_box[0], ocr_box[1], cells_to_split[0][0], ocr_box[3]]);
            }

            for (i, current_cell) in cells_to_split.iter().enumerate() {
                // The piece of the OCR box that falls inside this cell.
                split_boxes.push([
                    ocr_box[0].max(current_cell[0]),
                    ocr_box[1],
                    ocr_box[2].min(current_cell[2]),
                    ocr_box[3],
                ]);

                // Gap between this cell and the next, if any.
                if i + 1 < cells_to_split.len() {
                    let next_cell = cells_to_split[i + 1];
                    if current_cell[2] < next_cell[0] {
                        split_boxes.push([current_cell[2], ocr_box[1], next_cell[0], ocr_box[3]]);
                    }
                }
            }

            // Trailing overhang after the last cell.
            let last_cell = cells_to_split[cells_to_split.len() - 1];
            if last_cell[2] < ocr_box[2] {
                split_boxes.push([last_cell[2], ocr_box[1], ocr_box[2], ocr_box[3]]);
            }

            // Deduplicate exact repeats via bit patterns (f32 is not Hash).
            let mut unique = Vec::new();
            let mut seen = std::collections::HashSet::new();
            for b in split_boxes {
                let key = (
                    b[0].to_bits(),
                    b[1].to_bits(),
                    b[2].to_bits(),
                    b[3].to_bits(),
                );
                if seen.insert(key) {
                    unique.push(b);
                }
            }
            unique
        }

        // Only split boxes that cover at least two cells.
        let k_min_cells = 2usize;
        let overlap_threshold = CELL_OVERLAP_IOU_THRESHOLD;

        let mut new_regions: Vec<crate::oarocr::TextRegion> =
            Vec::with_capacity(text_regions.len());

        for region in text_regions.iter() {
            let ocr_box = [
                region.bounding_box.x_min(),
                region.bounding_box.y_min(),
                region.bounding_box.x_max(),
                region.bounding_box.y_max(),
            ];

            let overlapping_cells = get_overlapping_cells(&ocr_box, &cell_boxes, overlap_threshold);

            if overlapping_cells.len() < k_min_cells {
                // Not a multi-cell box: keep the region as-is.
                new_regions.push(region.clone());
                continue;
            }

            let split_boxes = split_box_by_cells(&ocr_box, &overlapping_cells, &cell_boxes);

            for box_coords in split_boxes {
                let img_w = page_image.width() as i32;
                let img_h = page_image.height() as i32;

                // Snap the float box to pixel bounds (floor/ceil) and clamp
                // to the image; skip anything degenerate (<= 1px).
                let mut x1 = box_coords[0].floor() as i32;
                let mut y1 = box_coords[1].floor() as i32;
                let mut x2 = box_coords[2].ceil() as i32;
                let mut y2 = box_coords[3].ceil() as i32;

                x1 = x1.clamp(0, img_w.saturating_sub(1));
                y1 = y1.clamp(0, img_h.saturating_sub(1));
                x2 = x2.clamp(0, img_w);
                y2 = y2.clamp(0, img_h);

                if x2 - x1 <= 1 || y2 - y1 <= 1 {
                    continue;
                }

                let crop_w = (x2 - x1) as u32;
                let crop_h = (y2 - y1) as u32;
                if crop_w <= 1 || crop_h <= 1 {
                    continue;
                }

                let x1u = x1 as u32;
                let y1u = y1 as u32;
                if x1u >= page_image.width() || y1u >= page_image.height() {
                    continue;
                }
                // Shrink the crop so it never reads past the image edge.
                let crop_w = crop_w.min(page_image.width() - x1u);
                let crop_h = crop_h.min(page_image.height() - y1u);
                if crop_w <= 1 || crop_h <= 1 {
                    continue;
                }

                let crop =
                    image::imageops::crop_imm(page_image, x1u, y1u, crop_w, crop_h).to_image();

                // Re-recognize the sub-box; empty results are dropped.
                let rec_input = ImageTaskInput::new(vec![crop]);
                let rec_result = text_recognition_adapter.execute(rec_input, None)?;
                if let (Some(text), Some(score)) =
                    (rec_result.texts.first(), rec_result.scores.first())
                    && !text.is_empty()
                {
                    // The stored bbox keeps the original float coordinates,
                    // not the pixel-snapped crop rectangle.
                    let bbox = BoundingBox::from_coords(
                        box_coords[0],
                        box_coords[1],
                        box_coords[2],
                        box_coords[3],
                    );
                    new_regions.push(crate::oarocr::TextRegion {
                        bounding_box: bbox.clone(),
                        dt_poly: Some(bbox.clone()),
                        rec_poly: Some(bbox),
                        text: Some(Arc::from(text.as_str())),
                        confidence: Some(*score),
                        orientation_angle: None,
                        word_boxes: None,
                        label: None,
                    });
                }
            }
        }

        *text_regions = new_regions;
        Ok(())
    }
1644
    /// Runs layout detection (and, when configured, region detection) on a
    /// page image.
    ///
    /// Returns the detected layout elements — after overlapping-element
    /// removal and standardized label fixes — plus the optional region
    /// blocks. Region detection failures are swallowed (the adapter is
    /// best-effort); layout detection failures are propagated.
    ///
    /// # Errors
    /// Returns any `OCRError` raised by the layout detection adapter.
    fn detect_layout_and_regions(
        &self,
        page_image: &image::RgbImage,
    ) -> Result<
        (
            Vec<crate::domain::structure::LayoutElement>,
            Option<Vec<crate::domain::structure::RegionBlock>>,
        ),
        OCRError,
    > {
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, RegionBlock};

        let input = ImageTaskInput::new(vec![page_image.clone()]);
        let layout_result = self
            .pipeline
            .layout_detection_adapter
            .execute(input, None)?;

        // Convert raw adapter output (single image => first element list)
        // into domain LayoutElements, keeping the raw label string too.
        let mut layout_elements: Vec<LayoutElement> = Vec::new();
        if let Some(elements) = layout_result.elements.first() {
            for element in elements {
                let element_type_enum = LayoutElementType::from_label(&element.element_type);
                layout_elements.push(
                    LayoutElement::new(element.bbox.clone(), element_type_enum, element.score)
                        .with_label(element.element_type.clone()),
                );
            }
        }

        // Region detection is optional and best-effort: errors and empty
        // results simply leave `detected_region_blocks` as None.
        let mut detected_region_blocks: Option<Vec<RegionBlock>> = None;
        if let Some(ref region_adapter) = self.pipeline.region_detection_adapter {
            let region_input = ImageTaskInput::new(vec![page_image.clone()]);
            if let Ok(region_result) = region_adapter.execute(region_input, None)
                && let Some(region_elements) = region_result.elements.first()
                && !region_elements.is_empty()
            {
                let blocks: Vec<RegionBlock> = region_elements
                    .iter()
                    .map(|e| RegionBlock {
                        bbox: e.bbox.clone(),
                        confidence: e.score,
                        order_index: None,
                        element_indices: Vec::new(),
                    })
                    .collect();
                detected_region_blocks = Some(blocks);
            }
        }

        // Prune heavily-overlapping layout elements (IoU threshold).
        if layout_elements.len() > 1 {
            let removed = crate::domain::structure::remove_overlapping_layout_elements(
                &mut layout_elements,
                LAYOUT_OVERLAP_IOU_THRESHOLD,
            );
            if removed > 0 {
                tracing::info!(
                    "Removing {} overlapping layout elements (threshold={})",
                    removed,
                    LAYOUT_OVERLAP_IOU_THRESHOLD
                );
            }
        }

        // Normalize label inconsistencies after pruning.
        crate::domain::structure::apply_standardized_layout_label_fixes(&mut layout_elements);

        Ok((layout_elements, detected_region_blocks))
    }
1713
    /// Recognizes LaTeX for every formula-typed layout element on the page.
    ///
    /// Returns an empty list when no formula adapter is configured, when no
    /// formula regions were detected, or when every crop fails. Crop failures
    /// are logged and skipped so `crops` and `bboxes` stay index-aligned for
    /// the result zip below.
    ///
    /// # Errors
    /// Propagates any `OCRError` raised by the formula recognition adapter.
    fn recognize_formulas(
        &self,
        page_image: &image::RgbImage,
        layout_elements: &[crate::domain::structure::LayoutElement],
    ) -> Result<Vec<crate::domain::structure::FormulaResult>, OCRError> {
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use oar_ocr_core::domain::structure::FormulaResult;
        use oar_ocr_core::utils::BBoxCrop;

        let Some(ref formula_adapter) = self.pipeline.formula_recognition_adapter else {
            return Ok(Vec::new());
        };

        let formula_elements: Vec<_> = layout_elements
            .iter()
            .filter(|e| e.element_type.is_formula())
            .collect();

        if formula_elements.is_empty() {
            tracing::debug!(
                "Formula recognition skipped: no formula regions from layout detection"
            );
            return Ok(Vec::new());
        }

        // Crops and their source boxes are pushed in lockstep; a failed crop
        // skips both, preserving the pairing used when zipping results.
        let mut crops = Vec::new();
        let mut bboxes = Vec::new();

        for elem in &formula_elements {
            match BBoxCrop::crop_bounding_box(page_image, &elem.bbox) {
                Ok(crop) => {
                    crops.push(crop);
                    bboxes.push(elem.bbox.clone());
                }
                Err(err) => {
                    tracing::warn!("Formula region crop failed: {}", err);
                }
            }
        }

        if crops.is_empty() {
            tracing::debug!(
                "Formula recognition skipped: all formula crops failed for {} regions",
                formula_elements.len()
            );
            return Ok(Vec::new());
        }

        let input = ImageTaskInput::new(crops);
        let formula_result = formula_adapter.execute(input, None)?;

        // Pair each source bbox with its recognized LaTeX and score,
        // discarding boxes with non-positive dimensions.
        let mut formulas = Vec::new();
        for ((bbox, formula), score) in bboxes
            .into_iter()
            .zip(formula_result.formulas.into_iter())
            .zip(formula_result.scores.into_iter())
        {
            let width = bbox.x_max() - bbox.x_min();
            let height = bbox.y_max() - bbox.y_min();
            if width <= 0.0 || height <= 0.0 {
                tracing::warn!(
                    "Skipping formula with non-positive bbox dimensions: w={:.2}, h={:.2}",
                    width,
                    height
                );
                continue;
            }

            formulas.push(FormulaResult {
                bbox,
                latex: formula,
                // A missing score is reported as zero confidence.
                confidence: score.unwrap_or(0.0),
            });
        }

        Ok(formulas)
    }
1791
    /// Runs seal-text detection inside every seal-typed layout region and
    /// appends the detections to `layout_elements` as new Seal elements.
    ///
    /// Detection coordinates are produced in crop-local space, so each point
    /// is translated back by the crop's (x_min, y_min) offset before being
    /// stored. A no-op when no seal adapter is configured or no seal regions
    /// were detected; individual crop failures are logged and skipped.
    ///
    /// # Errors
    /// Propagates any `OCRError` raised by the seal detection adapter.
    fn detect_seal_text(
        &self,
        page_image: &image::RgbImage,
        layout_elements: &mut Vec<crate::domain::structure::LayoutElement>,
    ) -> Result<(), OCRError> {
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType};
        use oar_ocr_core::processors::Point;
        use oar_ocr_core::utils::BBoxCrop;

        let Some(ref seal_adapter) = self.pipeline.seal_text_detection_adapter else {
            return Ok(());
        };

        // Snapshot the seal region boxes first: the loop below pushes new
        // elements into the same `layout_elements` vector.
        let seal_regions: Vec<_> = layout_elements
            .iter()
            .filter(|e| e.element_type == LayoutElementType::Seal)
            .map(|e| e.bbox.clone())
            .collect();

        if seal_regions.is_empty() {
            tracing::debug!("Seal detection skipped: no seal regions from layout detection");
            return Ok(());
        }

        // Crops and their page-space offsets are kept index-aligned.
        let mut seal_crops = Vec::new();
        let mut crop_offsets = Vec::new();

        for region_bbox in &seal_regions {
            match BBoxCrop::crop_bounding_box(page_image, region_bbox) {
                Ok(crop) => {
                    seal_crops.push(crop);
                    crop_offsets.push((region_bbox.x_min(), region_bbox.y_min()));
                }
                Err(err) => {
                    tracing::warn!("Seal region crop failed: {}", err);
                }
            }
        }

        if seal_crops.is_empty() {
            return Ok(());
        }

        let input = ImageTaskInput::new(seal_crops);
        let seal_result = seal_adapter.execute(input, None)?;

        // Translate each detection from crop-local to page coordinates and
        // record it as an additional Seal layout element.
        for ((dx, dy), detections) in crop_offsets.iter().zip(seal_result.detections.into_iter()) {
            for detection in detections {
                let translated_bbox = crate::processors::BoundingBox::new(
                    detection
                        .bbox
                        .points
                        .iter()
                        .map(|p| Point::new(p.x + dx, p.y + dy))
                        .collect(),
                );

                layout_elements.push(
                    LayoutElement::new(translated_bbox, LayoutElementType::Seal, detection.score)
                        .with_label("seal".to_string()),
                );
            }
        }

        Ok(())
    }
1859
1860 fn sort_layout_elements_enhanced(
1861 layout_elements: &mut Vec<crate::domain::structure::LayoutElement>,
1862 page_width: f32,
1863 page_height: f32,
1864 ) {
1865 use oar_ocr_core::processors::layout_sorting::{SortableElement, sort_layout_enhanced};
1866
1867 if layout_elements.is_empty() {
1868 return;
1869 }
1870
1871 let sortable_elements: Vec<_> = layout_elements
1872 .iter()
1873 .map(|e| SortableElement {
1874 bbox: e.bbox.clone(),
1875 element_type: e.element_type,
1876 num_lines: e.num_lines,
1877 })
1878 .collect();
1879
1880 let sorted_indices = sort_layout_enhanced(&sortable_elements, page_width, page_height);
1881 if sorted_indices.len() != layout_elements.len() {
1882 return;
1883 }
1884
1885 let sorted_elements: Vec<_> = sorted_indices
1886 .into_iter()
1887 .map(|idx| layout_elements[idx].clone())
1888 .collect();
1889 *layout_elements = sorted_elements;
1890 }
1891
1892 fn assign_region_block_membership(
1893 region_blocks: &mut [crate::domain::structure::RegionBlock],
1894 layout_elements: &[crate::domain::structure::LayoutElement],
1895 ) {
1896 use std::cmp::Ordering;
1897
1898 if region_blocks.is_empty() {
1899 return;
1900 }
1901
1902 region_blocks.sort_by(|a, b| {
1903 a.bbox
1904 .y_min()
1905 .partial_cmp(&b.bbox.y_min())
1906 .unwrap_or(Ordering::Equal)
1907 .then_with(|| {
1908 a.bbox
1909 .x_min()
1910 .partial_cmp(&b.bbox.x_min())
1911 .unwrap_or(Ordering::Equal)
1912 })
1913 });
1914
1915 for (i, region) in region_blocks.iter_mut().enumerate() {
1916 region.order_index = Some((i + 1) as u32);
1917 region.element_indices.clear();
1918 }
1919
1920 if layout_elements.is_empty() {
1921 return;
1922 }
1923
1924 for (elem_idx, elem) in layout_elements.iter().enumerate() {
1925 let elem_area = elem.bbox.area();
1926 if elem_area <= 0.0 {
1927 continue;
1928 }
1929
1930 let mut best_region: Option<usize> = None;
1931 let mut best_ioa = 0.0f32;
1932
1933 for (region_idx, region) in region_blocks.iter().enumerate() {
1934 let intersection = elem.bbox.intersection_area(®ion.bbox);
1935 if intersection <= 0.0 {
1936 continue;
1937 }
1938 let ioa = intersection / elem_area;
1939 if ioa > best_ioa {
1940 best_ioa = ioa;
1941 best_region = Some(region_idx);
1942 }
1943 }
1944
1945 if let Some(region_idx) = best_region
1946 && best_ioa >= REGION_MEMBERSHIP_IOA_THRESHOLD
1947 {
1948 region_blocks[region_idx].element_indices.push(elem_idx);
1949 }
1950 }
1951 }
1952
    /// Runs whole-page text detection + recognition and returns the
    /// recognized text regions.
    ///
    /// Pipeline: mask out formula regions (they have their own recognizer),
    /// detect text boxes, split boxes that straddle multiple layout/region
    /// containers, sort boxes into reading order, crop, optionally fix
    /// upside-down lines, recognize in aspect-ratio-sorted batches, then
    /// refine the results against the layout via
    /// `refine_overall_ocr_with_layout`.
    ///
    /// Returns an empty list when either the text detection or the text
    /// recognition adapter is not configured.
    ///
    /// # Errors
    /// Propagates `OCRError` from detection, cropping, and the final
    /// layout-based refinement; per-batch recognition errors are tolerated
    /// (the affected boxes simply produce no text).
    fn run_overall_ocr(
        &self,
        page_image: &image::RgbImage,
        layout_elements: &[crate::domain::structure::LayoutElement],
        region_blocks: Option<&[crate::domain::structure::RegionBlock]>,
    ) -> Result<Vec<crate::oarocr::TextRegion>, OCRError> {
        use crate::oarocr::TextRegion;
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use std::sync::Arc;

        // Both adapters are required; without either there is nothing to do.
        let Some(ref text_detection_adapter) = self.pipeline.text_detection_adapter else {
            return Ok(Vec::new());
        };
        let Some(ref text_recognition_adapter) = self.pipeline.text_recognition_adapter else {
            return Ok(Vec::new());
        };

        let mut text_regions = Vec::new();

        // Paint formula regions white so the detector does not produce text
        // boxes inside them — formulas go through their own recognizer.
        let mut ocr_image = page_image.clone();
        let mask_bboxes: Vec<crate::processors::BoundingBox> = layout_elements
            .iter()
            .filter(|e| e.element_type.is_formula())
            .map(|e| e.bbox.clone())
            .collect();

        if !mask_bboxes.is_empty() {
            crate::utils::mask_regions(&mut ocr_image, &mask_bboxes, [255, 255, 255]);
        }

        let input = ImageTaskInput::new(vec![ocr_image.clone()]);
        let det_result = text_detection_adapter.execute(input, None)?;

        // Single-image input => first detection list holds all boxes.
        let mut detection_boxes = if let Some(detections) = det_result.detections.first() {
            detections
                .iter()
                .map(|d| d.bbox.clone())
                .collect::<Vec<_>>()
        } else {
            Vec::new()
        };

        let raw_detection_boxes = detection_boxes.clone();
        if tracing::enabled!(tracing::Level::DEBUG) && !raw_detection_boxes.is_empty() {
            let raw_rects: Vec<[f32; 4]> = raw_detection_boxes
                .iter()
                .map(|b| [b.x_min(), b.y_min(), b.x_max(), b.y_max()])
                .collect();
            tracing::debug!("overall OCR text det boxes (raw): {:?}", raw_rects);
        }

        // Split detection boxes that span multiple "container" boxes
        // (region blocks when available, otherwise text-like layout types)
        // so each piece can be recognized in its own context.
        if !detection_boxes.is_empty() {
            let mut split_boxes = Vec::new();
            let mut split_count = 0usize;

            let container_boxes: Vec<crate::processors::BoundingBox> =
                if let Some(regions) = region_blocks {
                    regions.iter().map(|r| r.bbox.clone()).collect()
                } else {
                    layout_elements
                        .iter()
                        .filter(|e| {
                            matches!(
                                e.element_type,
                                crate::domain::structure::LayoutElementType::DocTitle
                                    | crate::domain::structure::LayoutElementType::ParagraphTitle
                                    | crate::domain::structure::LayoutElementType::Text
                                    | crate::domain::structure::LayoutElementType::Content
                                    | crate::domain::structure::LayoutElementType::Abstract
                                    | crate::domain::structure::LayoutElementType::Header
                                    | crate::domain::structure::LayoutElementType::Footer
                                    | crate::domain::structure::LayoutElementType::Footnote
                                    | crate::domain::structure::LayoutElementType::Number
                                    | crate::domain::structure::LayoutElementType::Reference
                                    | crate::domain::structure::LayoutElementType::ReferenceContent
                                    | crate::domain::structure::LayoutElementType::Algorithm
                                    | crate::domain::structure::LayoutElementType::AsideText
                                    | crate::domain::structure::LayoutElementType::List
                                    | crate::domain::structure::LayoutElementType::FigureTitle
                                    | crate::domain::structure::LayoutElementType::TableTitle
                                    | crate::domain::structure::LayoutElementType::ChartTitle
                                    | crate::domain::structure::LayoutElementType::FigureTableChartTitle
                            )
                        })
                        .map(|e| e.bbox.clone())
                        .collect()
                };

            if !container_boxes.is_empty() {
                for bbox in detection_boxes.into_iter() {
                    let mut intersections: Vec<crate::processors::BoundingBox> = Vec::new();
                    let self_area = bbox.area();
                    if self_area <= 0.0 {
                        // Degenerate box: keep it unchanged.
                        split_boxes.push(bbox);
                        continue;
                    }

                    // Collect intersections that cover a meaningful fraction
                    // of the detection box (IoA >= threshold), skipping
                    // slivers of 2px or less per axis.
                    for container in &container_boxes {
                        let inter_x_min = bbox.x_min().max(container.x_min());
                        let inter_y_min = bbox.y_min().max(container.y_min());
                        let inter_x_max = bbox.x_max().min(container.x_max());
                        let inter_y_max = bbox.y_max().min(container.y_max());

                        if inter_x_max - inter_x_min <= 2.0 || inter_y_max - inter_y_min <= 2.0 {
                            continue;
                        }

                        let inter_bbox = crate::processors::BoundingBox::from_coords(
                            inter_x_min,
                            inter_y_min,
                            inter_x_max,
                            inter_y_max,
                        );
                        let inter_area = inter_bbox.area();
                        if inter_area <= 0.0 {
                            continue;
                        }

                        let ioa = inter_area / self_area;
                        if ioa >= TEXT_BOX_SPLIT_IOA_THRESHOLD {
                            intersections.push(inter_bbox);
                        }
                    }

                    // Only replace the box when it truly straddles two or
                    // more containers; otherwise keep the original.
                    if intersections.len() >= 2 {
                        split_count += intersections.len();
                        split_boxes.extend(intersections);
                    } else {
                        split_boxes.push(bbox);
                    }
                }

                if split_count > 0 {
                    tracing::debug!(
                        "Cross-layout re-recognition: split {} text boxes into {} sub-boxes",
                        split_count,
                        split_boxes.len()
                    );
                }

                detection_boxes = split_boxes;
            }
        }

        // Order boxes into reading order before cropping/recognition.
        if !detection_boxes.is_empty() {
            detection_boxes = oar_ocr_core::processors::sort_quad_boxes(&detection_boxes);
        }

        if tracing::enabled!(tracing::Level::DEBUG) && !detection_boxes.is_empty() {
            let pre_rec_rects: Vec<[f32; 4]> = detection_boxes
                .iter()
                .map(|b| [b.x_min(), b.y_min(), b.x_max(), b.y_max()])
                .collect();
            tracing::debug!(
                "overall OCR boxes pre-recognition (after splitting): {:?}",
                pre_rec_rects
            );
        }

        if !detection_boxes.is_empty() {
            use crate::oarocr::processors::{EdgeProcessor, TextCroppingProcessor};

            // Crops are taken from the ORIGINAL page image, not the masked one.
            let processor = TextCroppingProcessor::new(true);
            let cropped =
                processor.process((Arc::new(page_image.clone()), detection_boxes.clone()))?;

            // Some crops may fail (None); remember which detection index
            // each surviving crop belongs to.
            let mut cropped_images: Vec<image::RgbImage> = Vec::new();
            let mut valid_indices: Vec<usize> = Vec::new();

            for (idx, crop_result) in cropped.into_iter().enumerate() {
                if let Some(img) = crop_result {
                    cropped_images.push((*img).clone());
                    valid_indices.push(idx);
                }
            }

            if !cropped_images.is_empty() {
                // Optional text-line orientation pass: class_id == 1 means
                // the line is upside down, so rotate its crop by 180°.
                if let Some(ref tlo_adapter) = self.pipeline.text_line_orientation_adapter {
                    let tlo_input = ImageTaskInput::new(cropped_images.clone());
                    if let Ok(tlo_result) = tlo_adapter.execute(tlo_input, None) {
                        for (i, classifications) in tlo_result.classifications.iter().enumerate() {
                            if i >= cropped_images.len() {
                                break;
                            }
                            if let Some(top_cls) = classifications.first()
                                && top_cls.class_id == 1
                            {
                                cropped_images[i] = image::imageops::rotate180(&cropped_images[i]);
                            }
                        }
                    }
                }

                // Sort crops by width/height ratio so each recognition batch
                // contains similarly-shaped images (less padding waste).
                let mut items: Vec<(usize, f32, image::RgbImage)> = valid_indices
                    .into_iter()
                    .zip(cropped_images)
                    .map(|(det_idx, img)| {
                        let wh_ratio = img.width() as f32 / img.height().max(1) as f32;
                        (det_idx, wh_ratio, img)
                    })
                    .collect();

                items.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                let batch_size = self.pipeline.region_batch_size.unwrap_or(8).max(1);
                // Slot per detection box; filled as batches come back so
                // output order matches detection order, not batch order.
                let mut recognized_by_det_idx: Vec<Option<(String, f32)>> =
                    vec![None; detection_boxes.len()];

                while !items.is_empty() {
                    let take_n = batch_size.min(items.len());
                    let batch_items: Vec<(usize, f32, image::RgbImage)> =
                        items.drain(0..take_n).collect();

                    let mut det_indices: Vec<usize> = Vec::with_capacity(batch_items.len());
                    let mut rec_imgs: Vec<image::RgbImage> = Vec::with_capacity(batch_items.len());
                    for (det_idx, _ratio, img) in batch_items {
                        det_indices.push(det_idx);
                        rec_imgs.push(img);
                    }

                    // Recognition failures for a batch are tolerated: the
                    // slots just stay None.
                    let rec_input = ImageTaskInput::new(rec_imgs);
                    if let Ok(rec_result) = text_recognition_adapter.execute(rec_input, None) {
                        for ((det_idx, text), score) in det_indices
                            .into_iter()
                            .zip(rec_result.texts.into_iter())
                            .zip(rec_result.scores.into_iter())
                        {
                            if text.is_empty() {
                                continue;
                            }
                            if let Some(slot) = recognized_by_det_idx.get_mut(det_idx) {
                                *slot = Some((text, score));
                            }
                        }
                    }
                }

                // Materialize the recognized slots as TextRegions.
                for (det_idx, rec) in recognized_by_det_idx.into_iter().enumerate() {
                    let Some((text, score)) = rec else {
                        continue;
                    };
                    let bbox = detection_boxes[det_idx].clone();
                    text_regions.push(TextRegion {
                        bounding_box: bbox.clone(),
                        dt_poly: Some(bbox.clone()),
                        rec_poly: Some(bbox),
                        text: Some(Arc::from(text)),
                        confidence: Some(score),
                        orientation_angle: None,
                        word_boxes: None,
                        label: None,
                    });
                }
            }
        }

        // Final layout-aware refinement (splitting + fallback recognition).
        let batch_size = self.pipeline.region_batch_size.unwrap_or(8).max(1);
        Self::refine_overall_ocr_with_layout(
            &mut text_regions,
            layout_elements,
            region_blocks,
            page_image,
            text_recognition_adapter,
            batch_size,
        )?;

        Ok(text_regions)
    }
2229
2230 pub fn predict(&self, image_path: impl Into<PathBuf>) -> Result<StructureResult, OCRError> {
2240 let image_path = image_path.into();
2241
2242 let image = image::open(&image_path).map_err(|e| OCRError::InvalidInput {
2244 message: format!(
2245 "failed to load image from '{}': {}",
2246 image_path.display(),
2247 e
2248 ),
2249 })?;
2250
2251 let mut result = self.predict_image(image.to_rgb8())?;
2252 result.input_path = std::sync::Arc::from(image_path.to_string_lossy().as_ref());
2253 Ok(result)
2254 }
2255
2256 fn prepare_page(&self, image: image::RgbImage) -> Result<PreparedPage, OCRError> {
2259 use crate::oarocr::preprocess::DocumentPreprocessor;
2260 use std::sync::Arc;
2261
2262 let preprocessor = DocumentPreprocessor::new(
2263 self.pipeline.document_orientation_adapter.as_ref(),
2264 self.pipeline.rectification_adapter.as_ref(),
2265 );
2266 let preprocess = preprocessor.preprocess(Arc::new(image))?;
2267 let current_image = preprocess.image;
2268 let orientation_angle = preprocess.orientation_angle;
2269 let rectified_img = preprocess.rectified_img;
2270 let rotation = preprocess.rotation;
2271
2272 let (layout_elements, detected_region_blocks) =
2273 self.detect_layout_and_regions(¤t_image)?;
2274
2275 Ok(PreparedPage {
2276 current_image,
2277 orientation_angle,
2278 rectified_img,
2279 rotation,
2280 layout_elements,
2281 detected_region_blocks,
2282 })
2283 }
2284
2285 fn complete_page(
2288 &self,
2289 prepared: PreparedPage,
2290 mut formulas: Vec<crate::domain::structure::FormulaResult>,
2291 ) -> Result<StructureResult, OCRError> {
2292 use std::sync::Arc;
2293
2294 let PreparedPage {
2295 current_image,
2296 orientation_angle,
2297 rectified_img,
2298 rotation,
2299 mut layout_elements,
2300 mut detected_region_blocks,
2301 } = prepared;
2302
2303 let mut tables = Vec::new();
2304
2305 self.detect_seal_text(¤t_image, &mut layout_elements)?;
2306
2307 if !layout_elements.is_empty() {
2310 let (width, height) = if let Some(img) = &rectified_img {
2311 (img.width() as f32, img.height() as f32)
2312 } else {
2313 (current_image.width() as f32, current_image.height() as f32)
2314 };
2315 Self::sort_layout_elements_enhanced(&mut layout_elements, width, height);
2316 }
2317
2318 if let Some(ref mut regions) = detected_region_blocks {
2319 Self::assign_region_block_membership(regions, &layout_elements);
2320 }
2321
2322 let mut text_regions = self.run_overall_ocr(
2323 ¤t_image,
2324 &layout_elements,
2325 detected_region_blocks.as_deref(),
2326 )?;
2327
2328 {
2329 let analyzer = crate::oarocr::table_analyzer::TableAnalyzer::new(
2330 crate::oarocr::table_analyzer::TableAnalyzerConfig {
2331 table_classification_adapter: self
2332 .pipeline
2333 .table_classification_adapter
2334 .as_ref(),
2335 table_orientation_adapter: self.pipeline.table_orientation_adapter.as_ref(),
2336 table_structure_recognition_adapter: self
2337 .pipeline
2338 .table_structure_recognition_adapter
2339 .as_ref(),
2340 wired_table_structure_adapter: self
2341 .pipeline
2342 .wired_table_structure_adapter
2343 .as_ref(),
2344 wireless_table_structure_adapter: self
2345 .pipeline
2346 .wireless_table_structure_adapter
2347 .as_ref(),
2348 table_cell_detection_adapter: self
2349 .pipeline
2350 .table_cell_detection_adapter
2351 .as_ref(),
2352 wired_table_cell_adapter: self.pipeline.wired_table_cell_adapter.as_ref(),
2353 wireless_table_cell_adapter: self.pipeline.wireless_table_cell_adapter.as_ref(),
2354 use_e2e_wired_table_rec: self.pipeline.use_e2e_wired_table_rec,
2355 use_e2e_wireless_table_rec: self.pipeline.use_e2e_wireless_table_rec,
2356 use_wired_table_cells_trans_to_html: self
2357 .pipeline
2358 .use_wired_table_cells_trans_to_html,
2359 use_wireless_table_cells_trans_to_html: self
2360 .pipeline
2361 .use_wireless_table_cells_trans_to_html,
2362 },
2363 );
2364 tables.extend(analyzer.analyze_tables(
2365 ¤t_image,
2366 &layout_elements,
2367 &formulas,
2368 &text_regions,
2369 )?);
2370 }
2371
2372 let has_detection_backed_table_cells = tables.iter().any(|table| !table.is_e2e);
2380 if has_detection_backed_table_cells
2381 && !text_regions.is_empty()
2382 && let Some(ref text_rec_adapter) = self.pipeline.text_recognition_adapter
2383 {
2384 Self::split_ocr_bboxes_by_table_cells(
2385 &tables,
2386 &mut text_regions,
2387 ¤t_image,
2388 text_rec_adapter,
2389 )?;
2390 }
2391
2392 if let Some(rot) = rotation {
2395 let rotated_width = rot.rotated_width;
2396 let rotated_height = rot.rotated_height;
2397 let angle = rot.angle;
2398
2399 for element in &mut layout_elements {
2401 element.bbox =
2402 element
2403 .bbox
2404 .rotate_back_to_original(angle, rotated_width, rotated_height);
2405 }
2406
2407 for table in &mut tables {
2409 table.bbox =
2410 table
2411 .bbox
2412 .rotate_back_to_original(angle, rotated_width, rotated_height);
2413
2414 for cell in &mut table.cells {
2416 cell.bbox =
2417 cell.bbox
2418 .rotate_back_to_original(angle, rotated_width, rotated_height);
2419 }
2420 }
2421
2422 for formula in &mut formulas {
2424 formula.bbox =
2425 formula
2426 .bbox
2427 .rotate_back_to_original(angle, rotated_width, rotated_height);
2428 }
2429
2430 for region in &mut text_regions {
2432 region.dt_poly = region
2433 .dt_poly
2434 .take()
2435 .map(|poly| poly.rotate_back_to_original(angle, rotated_width, rotated_height));
2436 region.rec_poly = region
2437 .rec_poly
2438 .take()
2439 .map(|poly| poly.rotate_back_to_original(angle, rotated_width, rotated_height));
2440 region.bounding_box = region.bounding_box.rotate_back_to_original(
2441 angle,
2442 rotated_width,
2443 rotated_height,
2444 );
2445
2446 if let Some(ref word_boxes) = region.word_boxes {
2447 let transformed_word_boxes: Vec<_> = word_boxes
2448 .iter()
2449 .map(|wb| wb.rotate_back_to_original(angle, rotated_width, rotated_height))
2450 .collect();
2451 region.word_boxes = Some(transformed_word_boxes);
2452 }
2453 }
2454
2455 if let Some(ref mut regions) = detected_region_blocks {
2457 for region in regions.iter_mut() {
2458 region.bbox =
2459 region
2460 .bbox
2461 .rotate_back_to_original(angle, rotated_width, rotated_height);
2462 }
2463 }
2464 }
2465
2466 for formula in &formulas {
2472 let w = formula.bbox.x_max() - formula.bbox.x_min();
2473 let h = formula.bbox.y_max() - formula.bbox.y_min();
2474 if w > 1.0 && h > 1.0 {
2475 let mut region = crate::oarocr::TextRegion::new(formula.bbox.clone());
2476 region.text = Some(formula.latex.clone().into());
2477 region.confidence = Some(1.0);
2478 region.label = Some("formula".into()); text_regions.push(region);
2480 }
2481 }
2482
2483 let final_image = rectified_img.unwrap_or_else(|| Arc::new((*current_image).clone()));
2487 let mut result = StructureResult {
2488 input_path: Arc::from("memory"),
2489 index: 0,
2490 layout_elements,
2491 tables,
2492 formulas,
2493 text_regions: if text_regions.is_empty() {
2494 None
2495 } else {
2496 Some(text_regions)
2497 },
2498 orientation_angle,
2499 region_blocks: detected_region_blocks,
2500 rectified_img: Some(final_image),
2501 page_continuation_flags: None,
2502 };
2503
2504 use crate::oarocr::stitching::{ResultStitcher, StitchConfig};
2507 let stitch_cfg = StitchConfig::default();
2508 ResultStitcher::stitch_with_config(&mut result, &stitch_cfg);
2509
2510 Ok(result)
2511 }
2512
2513 pub fn predict_image(&self, image: image::RgbImage) -> Result<StructureResult, OCRError> {
2515 let prepared = self.prepare_page(image)?;
2516 let formulas =
2517 self.recognize_formulas(&prepared.current_image, &prepared.layout_elements)?;
2518 self.complete_page(prepared, formulas)
2519 }
2520
    /// Runs structure analysis over a batch of in-memory images.
    ///
    /// Pages are prepared individually, but formula recognition is batched
    /// across ALL pages so the formula adapter can run at its recommended
    /// batch size; results are then routed back to their source page and
    /// each page is completed independently.
    ///
    /// Returns one `Result` per input image, in input order. A page whose
    /// preparation failed yields its error; formula-recognition failures are
    /// logged and degrade to "no formulas" rather than failing the page.
    pub fn predict_images(
        &self,
        images: Vec<image::RgbImage>,
    ) -> Vec<Result<StructureResult, OCRError>> {
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use oar_ocr_core::domain::structure::FormulaResult;
        use oar_ocr_core::utils::BBoxCrop;

        if images.is_empty() {
            return Vec::new();
        }

        // Phase 1: per-page preparation (orientation, rectification, layout).
        let prepared_pages: Vec<Result<PreparedPage, OCRError>> = images
            .into_iter()
            .map(|image| self.prepare_page(image))
            .collect();

        // One formula bucket per page, indexed in lock-step with the inputs.
        let num_pages = prepared_pages.len();
        let mut per_page_formulas: Vec<Vec<FormulaResult>> =
            (0..num_pages).map(|_| Vec::new()).collect();

        if let Some(ref formula_adapter) = self.pipeline.formula_recognition_adapter {
            // Phase 2: gather every formula-region crop across all
            // successfully prepared pages, remembering which page and bbox
            // each crop came from so results can be routed back.
            let mut all_crops: Vec<image::RgbImage> = Vec::new();
            let mut crop_meta: Vec<(usize, oar_ocr_core::processors::BoundingBox)> = Vec::new();

            for (page_idx, prepared) in prepared_pages.iter().enumerate() {
                let prepared = match prepared {
                    Ok(p) => p,
                    // Failed pages contribute no crops; their error is
                    // surfaced when the page is completed below.
                    Err(_) => continue,
                };
                for elem in prepared
                    .layout_elements
                    .iter()
                    .filter(|e| e.element_type.is_formula())
                {
                    match BBoxCrop::crop_bounding_box(&prepared.current_image, &elem.bbox) {
                        Ok(crop) => {
                            all_crops.push(crop);
                            crop_meta.push((page_idx, elem.bbox.clone()));
                        }
                        Err(err) => {
                            // Best-effort: a bad crop skips that formula only.
                            tracing::warn!("Formula region crop failed (batch): {}", err);
                        }
                    }
                }
            }

            if !all_crops.is_empty() {
                let batch_size = formula_adapter.recommended_batch_size().max(1);
                let mut remaining_crops = all_crops;
                // Offset into crop_meta matching the front of remaining_crops.
                let mut meta_offset = 0;

                while !remaining_crops.is_empty() {
                    // Carve the next chunk off the front of the crop queue.
                    let chunk_len = batch_size.min(remaining_crops.len());
                    let rest = remaining_crops.split_off(chunk_len);
                    let chunk_vec = remaining_crops;
                    remaining_crops = rest;

                    let chunk_meta = &crop_meta[meta_offset..meta_offset + chunk_len];
                    match formula_adapter.execute(ImageTaskInput::new(chunk_vec), None) {
                        Ok(formula_output) => {
                            // Pair each recognized formula/score with the page
                            // index and bbox its crop originated from; zip
                            // truncates if the adapter returned fewer results.
                            for ((page_idx, bbox), (formula_text, score)) in
                                chunk_meta.iter().cloned().zip(
                                    formula_output
                                        .formulas
                                        .into_iter()
                                        .zip(formula_output.scores),
                                )
                            {
                                // Skip degenerate (zero-area) boxes.
                                let width = bbox.x_max() - bbox.x_min();
                                let height = bbox.y_max() - bbox.y_min();
                                if width > 0.0 && height > 0.0 {
                                    per_page_formulas[page_idx].push(FormulaResult {
                                        bbox,
                                        latex: formula_text,
                                        confidence: score.unwrap_or(0.0),
                                    });
                                }
                            }
                        }
                        Err(err) => {
                            // A failed batch drops its formulas but does not
                            // fail the pages they belong to.
                            tracing::warn!("Batch formula recognition failed: {}", err);
                        }
                    }
                    meta_offset += chunk_len;
                }
            }
        }

        // Phase 3: complete each page with its routed formulas.
        prepared_pages
            .into_iter()
            .zip(per_page_formulas)
            .map(|(prepared, formulas)| self.complete_page(prepared?, formulas))
            .collect()
    }
2629}
2630
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed builder records the layout model path and
    /// leaves the optional table/formula components unset.
    #[test]
    fn test_structure_builder_new() {
        let b = OARStructureBuilder::new("models/layout.onnx");
        assert_eq!(b.layout_detection_model, PathBuf::from("models/layout.onnx"));
        assert!(b.table_classification_model.is_none());
        assert!(b.formula_recognition_model.is_none());
    }

    /// The table-related setters populate the corresponding model paths,
    /// type tags, and structure-dictionary path.
    #[test]
    fn test_structure_builder_with_table_components() {
        let b = OARStructureBuilder::new("models/layout.onnx")
            .with_table_classification("models/table_cls.onnx")
            .with_table_cell_detection("models/table_cell.onnx", "wired")
            .with_table_structure_recognition("models/table_struct.onnx", "wired")
            .table_structure_dict_path("models/table_structure_dict.txt");

        assert!(b.table_classification_model.is_some());
        assert!(b.table_cell_detection_model.is_some());
        assert!(b.table_structure_recognition_model.is_some());
        assert_eq!(b.table_cell_detection_type, Some("wired".to_string()));
        assert_eq!(b.table_structure_recognition_type, Some("wired".to_string()));
        assert_eq!(
            b.table_structure_dict_path,
            Some(PathBuf::from("models/table_structure_dict.txt"))
        );
    }

    /// Formula configuration stores the model, tokenizer, and type tag.
    #[test]
    fn test_structure_builder_with_formula() {
        let b = OARStructureBuilder::new("models/layout.onnx").with_formula_recognition(
            "models/formula.onnx",
            "models/tokenizer.json",
            "pp_formulanet",
        );

        assert!(b.formula_recognition_model.is_some());
        assert!(b.formula_tokenizer_path.is_some());
        assert_eq!(b.formula_recognition_type, Some("pp_formulanet".to_string()));
    }

    /// OCR configuration stores detection model, recognition model, and
    /// character dictionary.
    #[test]
    fn test_structure_builder_with_ocr() {
        let b = OARStructureBuilder::new("models/layout.onnx").with_ocr(
            "models/det.onnx",
            "models/rec.onnx",
            "models/dict.txt",
        );

        assert!(b.text_detection_model.is_some());
        assert!(b.text_recognition_model.is_some());
        assert!(b.character_dict_path.is_some());
    }

    /// Layout config and batch-size setters are recorded on the builder.
    #[test]
    fn test_structure_builder_with_configuration() {
        let layout_config = LayoutDetectionConfig {
            score_threshold: 0.5,
            max_elements: 100,
            ..Default::default()
        };

        let b = OARStructureBuilder::new("models/layout.onnx")
            .layout_detection_config(layout_config.clone())
            .image_batch_size(4)
            .region_batch_size(64);

        assert!(b.layout_detection_config.is_some());
        assert_eq!(b.image_batch_size, Some(4));
        assert_eq!(b.region_batch_size, Some(64));
    }
}