1use super::builder_utils::build_optional_adapter;
8use oar_ocr_core::core::OCRError;
9use oar_ocr_core::core::config::OrtSessionConfig;
10use oar_ocr_core::core::traits::OrtConfigurable;
11use oar_ocr_core::core::traits::adapter::{AdapterBuilder, ModelAdapter};
12use oar_ocr_core::domain::adapters::{
13 DocumentOrientationAdapter, DocumentOrientationAdapterBuilder, FormulaRecognitionAdapter,
14 LayoutDetectionAdapter, LayoutDetectionAdapterBuilder, PPFormulaNetAdapterBuilder,
15 SLANetWiredAdapterBuilder, SLANetWirelessAdapterBuilder, SealTextDetectionAdapter,
16 SealTextDetectionAdapterBuilder, TableCellDetectionAdapter, TableCellDetectionAdapterBuilder,
17 TableClassificationAdapter, TableClassificationAdapterBuilder,
18 TableStructureRecognitionAdapter, TextDetectionAdapter, TextDetectionAdapterBuilder,
19 TextLineOrientationAdapter, TextLineOrientationAdapterBuilder, TextRecognitionAdapter,
20 TextRecognitionAdapterBuilder, UVDocRectifierAdapter, UVDocRectifierAdapterBuilder,
21 UniMERNetAdapterBuilder,
22};
23use oar_ocr_core::domain::structure::{StructureResult, TableResult};
24use oar_ocr_core::domain::tasks::{
25 FormulaRecognitionConfig, LayoutDetectionConfig, TableCellDetectionConfig,
26 TableClassificationConfig, TableStructureRecognitionConfig, TextDetectionConfig,
27 TextRecognitionConfig,
28};
29use std::path::PathBuf;
30use std::sync::Arc;
31
/// IoU threshold for treating two layout boxes as overlapping.
/// NOTE(review): these four thresholds are not referenced in this chunk;
/// their consumers presumably live elsewhere in this module — confirm.
const LAYOUT_OVERLAP_IOU_THRESHOLD: f32 = 0.5;

/// IoU threshold for treating two table-cell boxes as overlapping.
const CELL_OVERLAP_IOU_THRESHOLD: f32 = 0.5;

/// IoA (intersection-over-area) threshold for assigning an element to a region.
const REGION_MEMBERSHIP_IOA_THRESHOLD: f32 = 0.1;

/// IoA threshold for deciding whether a text box should be split.
const TEXT_BOX_SPLIT_IOA_THRESHOLD: f32 = 0.3;
46#[derive(Debug)]
48struct StructurePipeline {
49 document_orientation_adapter: Option<DocumentOrientationAdapter>,
51 rectification_adapter: Option<UVDocRectifierAdapter>,
52
53 layout_detection_adapter: LayoutDetectionAdapter,
55
56 region_detection_adapter: Option<LayoutDetectionAdapter>,
58
59 table_classification_adapter: Option<TableClassificationAdapter>,
61 table_orientation_adapter: Option<DocumentOrientationAdapter>, table_cell_detection_adapter: Option<TableCellDetectionAdapter>,
63 table_structure_recognition_adapter: Option<TableStructureRecognitionAdapter>,
64 wired_table_structure_adapter: Option<TableStructureRecognitionAdapter>,
66 wireless_table_structure_adapter: Option<TableStructureRecognitionAdapter>,
67 wired_table_cell_adapter: Option<TableCellDetectionAdapter>,
68 wireless_table_cell_adapter: Option<TableCellDetectionAdapter>,
69 use_e2e_wired_table_rec: bool,
71 use_e2e_wireless_table_rec: bool,
72
73 formula_recognition_adapter: Option<FormulaRecognitionAdapter>,
74
75 seal_text_detection_adapter: Option<SealTextDetectionAdapter>,
76
77 text_detection_adapter: Option<TextDetectionAdapter>,
79 text_line_orientation_adapter: Option<TextLineOrientationAdapter>,
80 text_recognition_adapter: Option<TextRecognitionAdapter>,
81
82 region_batch_size: Option<usize>,
84}
85
86#[derive(Debug, Clone)]
120pub struct OARStructureBuilder {
121 layout_detection_model: PathBuf,
123 layout_model_name: Option<String>,
124
125 document_orientation_model: Option<PathBuf>,
127 document_rectification_model: Option<PathBuf>,
128
129 region_detection_model: Option<PathBuf>,
131
132 table_classification_model: Option<PathBuf>,
134 table_orientation_model: Option<PathBuf>, table_cell_detection_model: Option<PathBuf>,
136 table_cell_detection_type: Option<String>, table_structure_recognition_model: Option<PathBuf>,
138 table_structure_recognition_type: Option<String>, table_structure_dict_path: Option<PathBuf>,
140
141 wired_table_structure_model: Option<PathBuf>,
142 wireless_table_structure_model: Option<PathBuf>,
143 wired_table_cell_model: Option<PathBuf>,
144 wireless_table_cell_model: Option<PathBuf>,
145 use_e2e_wired_table_rec: bool,
148 use_e2e_wireless_table_rec: bool,
149
150 formula_recognition_model: Option<PathBuf>,
152 formula_recognition_type: Option<String>, formula_tokenizer_path: Option<PathBuf>,
154
155 seal_text_detection_model: Option<PathBuf>,
157
158 text_detection_model: Option<PathBuf>,
160 text_line_orientation_model: Option<PathBuf>,
161 text_recognition_model: Option<PathBuf>,
162 character_dict_path: Option<PathBuf>,
163
164 region_model_name: Option<String>,
166 wired_table_structure_model_name: Option<String>,
167 wireless_table_structure_model_name: Option<String>,
168 wired_table_cell_model_name: Option<String>,
169 wireless_table_cell_model_name: Option<String>,
170 text_detection_model_name: Option<String>,
171 text_recognition_model_name: Option<String>,
172
173 ort_session_config: Option<OrtSessionConfig>,
175 layout_detection_config: Option<LayoutDetectionConfig>,
176 table_classification_config: Option<TableClassificationConfig>,
177 table_cell_detection_config: Option<TableCellDetectionConfig>,
178 table_structure_recognition_config: Option<TableStructureRecognitionConfig>,
179 formula_recognition_config: Option<FormulaRecognitionConfig>,
180 text_detection_config: Option<TextDetectionConfig>,
181 text_recognition_config: Option<TextRecognitionConfig>,
182
183 image_batch_size: Option<usize>,
185 region_batch_size: Option<usize>,
186}
187
188impl OARStructureBuilder {
    /// Creates a builder with only the mandatory layout-detection model set.
    ///
    /// Every optional stage starts disabled (`None`). Boolean defaults:
    /// `use_e2e_wired_table_rec` starts `false` while
    /// `use_e2e_wireless_table_rec` starts `true`.
    pub fn new(layout_detection_model: impl Into<PathBuf>) -> Self {
        Self {
            layout_detection_model: layout_detection_model.into(),
            layout_model_name: None,
            document_orientation_model: None,
            document_rectification_model: None,
            region_detection_model: None,
            table_classification_model: None,
            table_orientation_model: None,
            table_cell_detection_model: None,
            table_cell_detection_type: None,
            table_structure_recognition_model: None,
            table_structure_recognition_type: None,
            table_structure_dict_path: None,
            wired_table_structure_model: None,
            wireless_table_structure_model: None,
            wired_table_cell_model: None,
            wireless_table_cell_model: None,
            use_e2e_wired_table_rec: false,
            // Wireless end-to-end recognition is on by default, unlike wired.
            use_e2e_wireless_table_rec: true,
            formula_recognition_model: None,
            formula_recognition_type: None,
            formula_tokenizer_path: None,
            seal_text_detection_model: None,
            text_detection_model: None,
            text_line_orientation_model: None,
            text_recognition_model: None,
            character_dict_path: None,
            region_model_name: None,
            wired_table_structure_model_name: None,
            wireless_table_structure_model_name: None,
            wired_table_cell_model_name: None,
            wireless_table_cell_model_name: None,
            text_detection_model_name: None,
            text_recognition_model_name: None,
            ort_session_config: None,
            layout_detection_config: None,
            table_classification_config: None,
            table_cell_detection_config: None,
            table_structure_recognition_config: None,
            formula_recognition_config: None,
            text_detection_config: None,
            text_recognition_config: None,
            image_batch_size: None,
            region_batch_size: None,
        }
    }
242
243 pub fn ort_session(mut self, config: OrtSessionConfig) -> Self {
247 self.ort_session_config = Some(config);
248 self
249 }
250
251 pub fn layout_detection_config(mut self, config: LayoutDetectionConfig) -> Self {
253 self.layout_detection_config = Some(config);
254 self
255 }
256
257 pub fn layout_model_name(mut self, name: impl Into<String>) -> Self {
267 self.layout_model_name = Some(name.into());
268 self
269 }
270
271 pub fn region_model_name(mut self, name: impl Into<String>) -> Self {
276 self.region_model_name = Some(name.into());
277 self
278 }
279
280 pub fn wired_table_structure_model_name(mut self, name: impl Into<String>) -> Self {
284 self.wired_table_structure_model_name = Some(name.into());
285 self
286 }
287
288 pub fn wireless_table_structure_model_name(mut self, name: impl Into<String>) -> Self {
292 self.wireless_table_structure_model_name = Some(name.into());
293 self
294 }
295
296 pub fn wired_table_cell_model_name(mut self, name: impl Into<String>) -> Self {
300 self.wired_table_cell_model_name = Some(name.into());
301 self
302 }
303
304 pub fn wireless_table_cell_model_name(mut self, name: impl Into<String>) -> Self {
308 self.wireless_table_cell_model_name = Some(name.into());
309 self
310 }
311
312 pub fn text_detection_model_name(mut self, name: impl Into<String>) -> Self {
316 self.text_detection_model_name = Some(name.into());
317 self
318 }
319
320 pub fn text_recognition_model_name(mut self, name: impl Into<String>) -> Self {
324 self.text_recognition_model_name = Some(name.into());
325 self
326 }
327
328 pub fn image_batch_size(mut self, size: usize) -> Self {
332 self.image_batch_size = Some(size);
333 self
334 }
335
336 pub fn region_batch_size(mut self, size: usize) -> Self {
341 self.region_batch_size = Some(size);
342 self
343 }
344
345 pub fn with_document_orientation(mut self, model_path: impl Into<PathBuf>) -> Self {
350 self.document_orientation_model = Some(model_path.into());
351 self
352 }
353
354 pub fn with_document_rectification(mut self, model_path: impl Into<PathBuf>) -> Self {
359 self.document_rectification_model = Some(model_path.into());
360 self
361 }
362
363 pub fn with_region_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
376 self.region_detection_model = Some(model_path.into());
377 self
378 }
379
380 pub fn with_seal_text_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
385 self.seal_text_detection_model = Some(model_path.into());
386 self
387 }
388
389 pub fn with_table_classification(mut self, model_path: impl Into<PathBuf>) -> Self {
393 self.table_classification_model = Some(model_path.into());
394 self
395 }
396
397 pub fn table_classification_config(mut self, config: TableClassificationConfig) -> Self {
399 self.table_classification_config = Some(config);
400 self
401 }
402
403 pub fn with_table_orientation(mut self, model_path: impl Into<PathBuf>) -> Self {
413 self.table_orientation_model = Some(model_path.into());
414 self
415 }
416
417 pub fn use_e2e_wired_table_rec(mut self, enabled: bool) -> Self {
425 self.use_e2e_wired_table_rec = enabled;
426 self
427 }
428
429 pub fn use_e2e_wireless_table_rec(mut self, enabled: bool) -> Self {
437 self.use_e2e_wireless_table_rec = enabled;
438 self
439 }
440
441 pub fn with_table_cell_detection(
448 mut self,
449 model_path: impl Into<PathBuf>,
450 cell_type: impl Into<String>,
451 ) -> Self {
452 self.table_cell_detection_model = Some(model_path.into());
453 self.table_cell_detection_type = Some(cell_type.into());
454 self
455 }
456
457 pub fn table_cell_detection_config(mut self, config: TableCellDetectionConfig) -> Self {
459 self.table_cell_detection_config = Some(config);
460 self
461 }
462
463 pub fn with_table_structure_recognition(
472 mut self,
473 model_path: impl Into<PathBuf>,
474 table_type: impl Into<String>,
475 ) -> Self {
476 self.table_structure_recognition_model = Some(model_path.into());
477 self.table_structure_recognition_type = Some(table_type.into());
478 self
479 }
480
481 pub fn table_structure_dict_path(mut self, path: impl Into<PathBuf>) -> Self {
488 self.table_structure_dict_path = Some(path.into());
489 self
490 }
491
492 pub fn table_structure_recognition_config(
494 mut self,
495 config: TableStructureRecognitionConfig,
496 ) -> Self {
497 self.table_structure_recognition_config = Some(config);
498 self
499 }
500
501 pub fn with_wired_table_structure(mut self, model_path: impl Into<PathBuf>) -> Self {
506 self.wired_table_structure_model = Some(model_path.into());
507 self
508 }
509
510 pub fn with_wireless_table_structure(mut self, model_path: impl Into<PathBuf>) -> Self {
515 self.wireless_table_structure_model = Some(model_path.into());
516 self
517 }
518
519 pub fn with_wired_table_cell_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
524 self.wired_table_cell_model = Some(model_path.into());
525 self
526 }
527
528 pub fn with_wireless_table_cell_detection(mut self, model_path: impl Into<PathBuf>) -> Self {
533 self.wireless_table_cell_model = Some(model_path.into());
534 self
535 }
536
537 pub fn with_formula_recognition(
547 mut self,
548 model_path: impl Into<PathBuf>,
549 tokenizer_path: impl Into<PathBuf>,
550 model_type: impl Into<String>,
551 ) -> Self {
552 self.formula_recognition_model = Some(model_path.into());
553 self.formula_tokenizer_path = Some(tokenizer_path.into());
554 self.formula_recognition_type = Some(model_type.into());
555 self
556 }
557
558 pub fn formula_recognition_config(mut self, config: FormulaRecognitionConfig) -> Self {
560 self.formula_recognition_config = Some(config);
561 self
562 }
563
564 pub fn with_ocr(
572 mut self,
573 text_detection_model: impl Into<PathBuf>,
574 text_recognition_model: impl Into<PathBuf>,
575 character_dict_path: impl Into<PathBuf>,
576 ) -> Self {
577 self.text_detection_model = Some(text_detection_model.into());
578 self.text_recognition_model = Some(text_recognition_model.into());
579 self.character_dict_path = Some(character_dict_path.into());
580 self
581 }
582
583 pub fn with_text_line_orientation(mut self, model_path: impl Into<PathBuf>) -> Self {
594 self.text_line_orientation_model = Some(model_path.into());
595 self
596 }
597
598 pub fn text_detection_config(mut self, config: TextDetectionConfig) -> Self {
600 self.text_detection_config = Some(config);
601 self
602 }
603
604 pub fn text_recognition_config(mut self, config: TextRecognitionConfig) -> Self {
606 self.text_recognition_config = Some(config);
607 self
608 }
609
    /// Consumes the builder and assembles the [`OARStructure`] pipeline.
    ///
    /// Reads the character dictionary (if configured), constructs each
    /// optional adapter from the corresponding model path and config, and
    /// wires everything into a `StructurePipeline`.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError`] when the character dictionary cannot be read,
    /// when a required companion setting is missing (table-structure dict
    /// path; formula tokenizer path or model type), when a type string is
    /// invalid ("wired"/"wireless", "pp_formulanet"/"unimernet"), or when
    /// any adapter fails to build.
    pub fn build(self) -> Result<OARStructure, OCRError> {
        // Load the recognition character dictionary up front so a bad path
        // fails early with a descriptive message.
        let char_dict = if let Some(ref dict_path) = self.character_dict_path {
            Some(
                std::fs::read_to_string(dict_path).map_err(|e| OCRError::InvalidInput {
                    message: format!(
                        "Failed to read character dictionary from '{}': {}",
                        dict_path.display(),
                        e
                    ),
                })?,
            )
        } else {
            None
        };

        // --- Document-level preprocessing (both optional). ---
        let document_orientation_adapter = build_optional_adapter(
            self.document_orientation_model.as_ref(),
            self.ort_session_config.as_ref(),
            DocumentOrientationAdapterBuilder::new,
        )?;

        let rectification_adapter = build_optional_adapter(
            self.document_rectification_model.as_ref(),
            self.ort_session_config.as_ref(),
            UVDocRectifierAdapterBuilder::new,
        )?;

        // --- Layout detection (the only mandatory adapter). ---
        let mut layout_builder = LayoutDetectionAdapterBuilder::new();

        // Resolve a named preset; unknown names deliberately fall back to
        // pp_doclayout_plus_l, the same default used when no name is set.
        let layout_model_config = if let Some(name) = &self.layout_model_name {
            use oar_ocr_core::domain::adapters::LayoutModelConfig;
            match name.as_str() {
                "picodet_layout_1x" => LayoutModelConfig::picodet_layout_1x(),
                "picodet_layout_1x_table" => LayoutModelConfig::picodet_layout_1x_table(),
                "picodet_s_layout_3cls" => LayoutModelConfig::picodet_s_layout_3cls(),
                "picodet_l_layout_3cls" => LayoutModelConfig::picodet_l_layout_3cls(),
                "picodet_s_layout_17cls" => LayoutModelConfig::picodet_s_layout_17cls(),
                "picodet_l_layout_17cls" => LayoutModelConfig::picodet_l_layout_17cls(),
                "rt-detr-h_layout_3cls" => LayoutModelConfig::rtdetr_h_layout_3cls(),
                "rt-detr-h_layout_17cls" => LayoutModelConfig::rtdetr_h_layout_17cls(),
                "pp-docblocklayout" => LayoutModelConfig::pp_docblocklayout(),
                "pp-doclayout-s" => LayoutModelConfig::pp_doclayout_s(),
                "pp-doclayout-m" => LayoutModelConfig::pp_doclayout_m(),
                "pp-doclayout-l" => LayoutModelConfig::pp_doclayout_l(),
                "pp-doclayout_plus-l" => LayoutModelConfig::pp_doclayout_plus_l(),
                _ => LayoutModelConfig::pp_doclayout_plus_l(),
            }
        } else {
            // NOTE(review): this branch uses a `crate::` path while the rest
            // of the file imports the same type via `oar_ocr_core::` —
            // confirm both resolve identically (likely a re-export).
            crate::domain::adapters::LayoutModelConfig::pp_doclayout_plus_l()
        };

        layout_builder = layout_builder.model_config(layout_model_config);

        // Fall back to PP-StructureV3 defaults when no explicit layout
        // config was provided.
        let effective_layout_cfg = self
            .layout_detection_config
            .clone()
            .unwrap_or_else(LayoutDetectionConfig::with_pp_structurev3_defaults);
        layout_builder = layout_builder.with_config(effective_layout_cfg);

        if let Some(ref ort_config) = self.ort_session_config {
            layout_builder = layout_builder.with_ort_config(ort_config.clone());
        }

        let layout_detection_adapter = layout_builder.build(&self.layout_detection_model)?;

        // --- Optional region (block-level) detection. ---
        let region_detection_adapter = if let Some(ref model_path) = self.region_detection_model {
            use oar_ocr_core::domain::adapters::LayoutModelConfig;
            let mut region_builder = LayoutDetectionAdapterBuilder::new();

            // All arms currently resolve to pp_docblocklayout; the match is
            // kept as an extension point for future region presets.
            let region_model_config = if let Some(ref name) = self.region_model_name {
                match name.to_lowercase().replace("-", "_").as_str() {
                    "pp_docblocklayout" => LayoutModelConfig::pp_docblocklayout(),
                    _ => LayoutModelConfig::pp_docblocklayout(),
                }
            } else {
                LayoutModelConfig::pp_docblocklayout()
            };
            region_builder = region_builder.model_config(region_model_config);

            // Region boxes of class "region" are merged with the "small" mode.
            let mut region_cfg = LayoutDetectionConfig::default();
            let mut merge_modes = std::collections::HashMap::new();
            merge_modes.insert(
                "region".to_string(),
                crate::domain::tasks::layout_detection::MergeBboxMode::Small,
            );
            region_cfg.class_merge_modes = Some(merge_modes);
            region_builder = region_builder.with_config(region_cfg);

            if let Some(ref ort_config) = self.ort_session_config {
                region_builder = region_builder.with_ort_config(ort_config.clone());
            }

            Some(region_builder.build(model_path)?)
        } else {
            None
        };

        // --- Table classification (wired vs wireless). ---
        let table_classification_adapter =
            if let Some(ref model_path) = self.table_classification_model {
                let mut builder = TableClassificationAdapterBuilder::new();

                if let Some(ref config) = self.table_classification_config {
                    builder = builder.with_config(config.clone());
                }

                if let Some(ref ort_config) = self.ort_session_config {
                    builder = builder.with_ort_config(ort_config.clone());
                }

                Some(builder.build(model_path)?)
            } else {
                None
            };

        // Table orientation reuses the document-orientation adapter builder.
        let table_orientation_adapter = build_optional_adapter(
            self.table_orientation_model.as_ref(),
            self.ort_session_config.as_ref(),
            DocumentOrientationAdapterBuilder::new,
        )?;

        // --- Generic table-cell detection ("wired" assumed if unset). ---
        let table_cell_detection_adapter = if let Some(ref model_path) =
            self.table_cell_detection_model
        {
            let cell_type = self.table_cell_detection_type.as_deref().unwrap_or("wired");

            use oar_ocr_core::domain::adapters::table_cell_detection_adapter::TableCellModelConfig;

            let model_config = match cell_type {
                "wired" => TableCellModelConfig::rtdetr_l_wired_table_cell_det(),
                "wireless" => TableCellModelConfig::rtdetr_l_wireless_table_cell_det(),
                _ => {
                    return Err(OCRError::config_error_detailed(
                        "table_cell_detection",
                        format!(
                            "Invalid cell type '{}': must be 'wired' or 'wireless'",
                            cell_type
                        ),
                    ));
                }
            };

            let mut builder = TableCellDetectionAdapterBuilder::new().model_config(model_config);

            if let Some(ref config) = self.table_cell_detection_config {
                builder = builder.with_config(config.clone());
            }

            if let Some(ref ort_config) = self.ort_session_config {
                builder = builder.with_ort_config(ort_config.clone());
            }

            Some(builder.build(model_path)?)
        } else {
            None
        };

        // --- Generic table-structure recognition (requires dict path). ---
        let table_structure_recognition_adapter = if let Some(ref model_path) =
            self.table_structure_recognition_model
        {
            let table_type = self
                .table_structure_recognition_type
                .as_deref()
                .unwrap_or("wired");
            let dict_path = self
                .table_structure_dict_path
                .clone()
                .ok_or_else(|| {
                    OCRError::config_error_detailed(
                        "table_structure_recognition",
                        "Dictionary path is required. Call table_structure_dict_path() when enabling table structure recognition.".to_string(),
                    )
                })?;

            // SLANet comes in two variants; pick by the declared table type.
            let adapter: TableStructureRecognitionAdapter = match table_type {
                "wired" => {
                    let mut builder = SLANetWiredAdapterBuilder::new().dict_path(dict_path.clone());

                    if let Some(ref config) = self.table_structure_recognition_config {
                        builder = builder.with_config(config.clone());
                    }

                    if let Some(ref ort_config) = self.ort_session_config {
                        builder = builder.with_ort_config(ort_config.clone());
                    }

                    builder.build(model_path)?
                }
                "wireless" => {
                    let mut builder =
                        SLANetWirelessAdapterBuilder::new().dict_path(dict_path.clone());

                    if let Some(ref config) = self.table_structure_recognition_config {
                        builder = builder.with_config(config.clone());
                    }

                    if let Some(ref ort_config) = self.ort_session_config {
                        builder = builder.with_ort_config(ort_config.clone());
                    }

                    builder.build(model_path)?
                }
                _ => {
                    return Err(OCRError::config_error_detailed(
                        "table_structure_recognition",
                        format!(
                            "Invalid table type '{}': must be 'wired' or 'wireless'",
                            table_type
                        ),
                    ));
                }
            };

            Some(adapter)
        } else {
            None
        };

        // --- Dedicated wired table-structure adapter (shares dict path). ---
        let wired_table_structure_adapter = if let Some(ref model_path) =
            self.wired_table_structure_model
        {
            let dict_path = self.table_structure_dict_path.clone().ok_or_else(|| {
                OCRError::config_error_detailed(
                    "wired_table_structure",
                    "Dictionary path is required. Call table_structure_dict_path() when enabling table structure recognition.".to_string(),
                )
            })?;

            let mut builder = SLANetWiredAdapterBuilder::new().dict_path(dict_path);

            if let Some(ref config) = self.table_structure_recognition_config {
                builder = builder.with_config(config.clone());
            }

            if let Some(ref ort_config) = self.ort_session_config {
                builder = builder.with_ort_config(ort_config.clone());
            }

            Some(builder.build(model_path)?)
        } else {
            None
        };

        // --- Dedicated wireless table-structure adapter. ---
        let wireless_table_structure_adapter = if let Some(ref model_path) =
            self.wireless_table_structure_model
        {
            let dict_path = self.table_structure_dict_path.clone().ok_or_else(|| {
                OCRError::config_error_detailed(
                    "wireless_table_structure",
                    "Dictionary path is required. Call table_structure_dict_path() when enabling table structure recognition.".to_string(),
                )
            })?;

            let mut builder = SLANetWirelessAdapterBuilder::new().dict_path(dict_path);

            if let Some(ref config) = self.table_structure_recognition_config {
                builder = builder.with_config(config.clone());
            }

            if let Some(ref ort_config) = self.ort_session_config {
                builder = builder.with_ort_config(ort_config.clone());
            }

            Some(builder.build(model_path)?)
        } else {
            None
        };

        // --- Dedicated wired table-cell adapter. ---
        let wired_table_cell_adapter = if let Some(ref model_path) = self.wired_table_cell_model {
            use oar_ocr_core::domain::adapters::table_cell_detection_adapter::TableCellModelConfig;

            let model_config = TableCellModelConfig::rtdetr_l_wired_table_cell_det();
            let mut builder = TableCellDetectionAdapterBuilder::new().model_config(model_config);

            if let Some(ref config) = self.table_cell_detection_config {
                builder = builder.with_config(config.clone());
            }

            if let Some(ref ort_config) = self.ort_session_config {
                builder = builder.with_ort_config(ort_config.clone());
            }

            Some(builder.build(model_path)?)
        } else {
            None
        };

        // --- Dedicated wireless table-cell adapter. ---
        let wireless_table_cell_adapter = if let Some(ref model_path) =
            self.wireless_table_cell_model
        {
            use oar_ocr_core::domain::adapters::table_cell_detection_adapter::TableCellModelConfig;

            let model_config = TableCellModelConfig::rtdetr_l_wireless_table_cell_det();
            let mut builder = TableCellDetectionAdapterBuilder::new().model_config(model_config);

            if let Some(ref config) = self.table_cell_detection_config {
                builder = builder.with_config(config.clone());
            }

            if let Some(ref ort_config) = self.ort_session_config {
                builder = builder.with_ort_config(ort_config.clone());
            }

            Some(builder.build(model_path)?)
        } else {
            None
        };

        // --- Formula recognition (tokenizer and model type are mandatory). ---
        let formula_recognition_adapter = if let Some(ref model_path) =
            self.formula_recognition_model
        {
            let tokenizer_path = self.formula_tokenizer_path.as_ref().ok_or_else(|| {
                OCRError::config_error_detailed(
                    "formula_recognition",
                    "Tokenizer path is required for formula recognition".to_string(),
                )
            })?;

            let model_type = self.formula_recognition_type.as_deref().ok_or_else(|| {
                OCRError::config_error_detailed(
                    "formula_recognition",
                    "Model type is required (must be 'pp_formulanet' or 'unimernet')".to_string(),
                )
            })?;

            // Case-insensitive dispatch; both spellings of PP-FormulaNet accepted.
            let adapter: FormulaRecognitionAdapter = match model_type.to_lowercase().as_str() {
                "pp_formulanet" | "pp-formulanet" => {
                    let mut builder = PPFormulaNetAdapterBuilder::new();

                    builder = builder.tokenizer_path(tokenizer_path);

                    if let Some(ref config) = self.formula_recognition_config {
                        builder = builder.task_config(config.clone());
                    }

                    if let Some(ref ort_config) = self.ort_session_config {
                        builder = builder.with_ort_config(ort_config.clone());
                    }

                    builder.build(model_path)?
                }
                "unimernet" => {
                    let mut builder = UniMERNetAdapterBuilder::new();

                    builder = builder.tokenizer_path(tokenizer_path);

                    if let Some(ref config) = self.formula_recognition_config {
                        builder = builder.task_config(config.clone());
                    }

                    if let Some(ref ort_config) = self.ort_session_config {
                        builder = builder.with_ort_config(ort_config.clone());
                    }

                    builder.build(model_path)?
                }
                _ => {
                    return Err(OCRError::config_error_detailed(
                        "formula_recognition",
                        format!(
                            "Invalid model type '{}': must be 'pp_formulanet' or 'unimernet'",
                            model_type
                        ),
                    ));
                }
            };

            Some(adapter)
        } else {
            None
        };

        // --- Seal-text detection. ---
        let seal_text_detection_adapter =
            if let Some(ref model_path) = self.seal_text_detection_model {
                let mut builder = SealTextDetectionAdapterBuilder::new();

                if let Some(ref ort_config) = self.ort_session_config {
                    builder = builder.with_ort_config(ort_config.clone());
                }

                Some(builder.build(model_path)?)
            } else {
                None
            };

        // --- Text detection, with defaults filled in when unset. ---
        let text_detection_adapter = if let Some(ref model_path) = self.text_detection_model {
            let mut builder = TextDetectionAdapterBuilder::new();

            // Default to limiting the short side to 736px when the caller
            // didn't specify sizing limits.
            let mut effective_cfg = self.text_detection_config.clone().unwrap_or_default();
            if effective_cfg.limit_side_len.is_none() {
                effective_cfg.limit_side_len = Some(736);
            }
            if effective_cfg.limit_type.is_none() {
                // NOTE(review): `crate::` path here vs `oar_ocr_core::`
                // elsewhere — confirm re-export.
                effective_cfg.limit_type = Some(crate::processors::LimitType::Min);
            }
            builder = builder.with_config(effective_cfg);

            if let Some(ref ort_config) = self.ort_session_config {
                builder = builder.with_ort_config(ort_config.clone());
            }

            Some(builder.build(model_path)?)
        } else {
            None
        };

        // --- Text-line orientation. ---
        let text_line_orientation_adapter =
            if let Some(ref model_path) = self.text_line_orientation_model {
                let mut builder = TextLineOrientationAdapterBuilder::new();

                if let Some(ref ort_config) = self.ort_session_config {
                    builder = builder.with_ort_config(ort_config.clone());
                }

                Some(builder.build(model_path)?)
            } else {
                None
            };

        // --- Text recognition (requires the character dictionary). ---
        let text_recognition_adapter = if let Some(ref model_path) = self.text_recognition_model {
            let dict = char_dict.ok_or_else(|| OCRError::InvalidInput {
                message: "Character dictionary is required for text recognition".to_string(),
            })?;

            // One dictionary entry per line.
            let char_vec: Vec<String> = dict.lines().map(|s| s.to_string()).collect();

            let mut builder = TextRecognitionAdapterBuilder::new().character_dict(char_vec);

            if let Some(ref config) = self.text_recognition_config {
                builder = builder.with_config(config.clone());
            }

            if let Some(ref ort_config) = self.ort_session_config {
                builder = builder.with_ort_config(ort_config.clone());
            }

            Some(builder.build(model_path)?)
        } else {
            None
        };

        // Assemble the pipeline. Note: `image_batch_size` is not forwarded
        // here — only `region_batch_size` reaches the pipeline.
        let pipeline = StructurePipeline {
            document_orientation_adapter,
            rectification_adapter,
            layout_detection_adapter,
            region_detection_adapter,
            table_classification_adapter,
            table_orientation_adapter,
            table_cell_detection_adapter,
            table_structure_recognition_adapter,
            wired_table_structure_adapter,
            wireless_table_structure_adapter,
            wired_table_cell_adapter,
            wireless_table_cell_adapter,
            use_e2e_wired_table_rec: self.use_e2e_wired_table_rec,
            use_e2e_wireless_table_rec: self.use_e2e_wireless_table_rec,
            formula_recognition_adapter,
            seal_text_detection_adapter,
            text_detection_adapter,
            text_line_orientation_adapter,
            text_recognition_adapter,
            region_batch_size: self.region_batch_size,
        };

        Ok(OARStructure { pipeline })
    }
1115}
1116
/// Document-structure analysis engine assembled by
/// [`OARStructureBuilder::build`].
#[derive(Debug)]
pub struct OARStructure {
    /// The fully-wired adapter set; see `StructurePipeline`.
    pipeline: StructurePipeline,
}
1124
1125impl OARStructure {
1126 fn refine_overall_ocr_with_layout(
1137 text_regions: &mut Vec<crate::oarocr::TextRegion>,
1138 layout_elements: &[crate::domain::structure::LayoutElement],
1139 region_blocks: Option<&[crate::domain::structure::RegionBlock]>,
1140 page_image: &image::RgbImage,
1141 text_recognition_adapter: &TextRecognitionAdapter,
1142 region_batch_size: usize,
1143 ) -> Result<(), OCRError> {
1144 use oar_ocr_core::core::traits::task::ImageTaskInput;
1145 use oar_ocr_core::domain::structure::LayoutElementType;
1146 use oar_ocr_core::processors::BoundingBox;
1147 use oar_ocr_core::utils::BBoxCrop;
1148
1149 if text_regions.is_empty() || layout_elements.is_empty() {
1150 return Ok(());
1151 }
1152
1153 fn aabb_intersection(b1: &BoundingBox, b2: &BoundingBox) -> Option<BoundingBox> {
1154 let x1 = b1.x_min().max(b2.x_min());
1155 let y1 = b1.y_min().max(b2.y_min());
1156 let x2 = b1.x_max().min(b2.x_max());
1157 let y2 = b1.y_max().min(b2.y_max());
1158 if x2 - x1 <= 1.0 || y2 - y1 <= 1.0 {
1159 None
1160 } else {
1161 Some(BoundingBox::from_coords(x1, y1, x2, y2))
1162 }
1163 }
1164
1165 let is_excluded_layout = |t: LayoutElementType| {
1167 matches!(
1168 t,
1169 LayoutElementType::Formula
1170 | LayoutElementType::FormulaNumber
1171 | LayoutElementType::Table
1172 | LayoutElementType::Seal
1173 )
1174 };
1175
1176 let min_pixels = 3.0;
1180 let mut matched_ocr: Vec<Vec<usize>> = vec![Vec::new(); text_regions.len()];
1181 for (ocr_idx, region) in text_regions.iter().enumerate() {
1182 for (layout_idx, elem) in layout_elements.iter().enumerate() {
1183 if is_excluded_layout(elem.element_type) {
1184 continue;
1185 }
1186 let inter_x_min = region.bounding_box.x_min().max(elem.bbox.x_min());
1187 let inter_y_min = region.bounding_box.y_min().max(elem.bbox.y_min());
1188 let inter_x_max = region.bounding_box.x_max().min(elem.bbox.x_max());
1189 let inter_y_max = region.bounding_box.y_max().min(elem.bbox.y_max());
1190 if inter_x_max - inter_x_min > min_pixels && inter_y_max - inter_y_min > min_pixels
1191 {
1192 matched_ocr[ocr_idx].push(layout_idx);
1193 }
1194 }
1195 }
1196
1197 let mut appended_regions: Vec<crate::oarocr::TextRegion> = Vec::new();
1199 let original_ocr_len = text_regions.len();
1200 let mut multi_layout_ocr_count = 0usize;
1201 let mut multi_layout_crop_count = 0usize;
1202
1203 for ocr_idx in 0..original_ocr_len {
1204 let layout_ids = matched_ocr[ocr_idx].clone();
1205 if layout_ids.len() <= 1 {
1206 continue;
1207 }
1208 multi_layout_ocr_count += 1;
1209
1210 let ocr_box = text_regions[ocr_idx].bounding_box.clone();
1211
1212 let mut crops: Vec<image::RgbImage> = Vec::new();
1213 let mut crop_boxes: Vec<(BoundingBox, bool)> = Vec::new(); for (j, layout_idx) in layout_ids.iter().enumerate() {
1216 let layout_box = &layout_elements[*layout_idx].bbox;
1217 let Some(crop_box) = aabb_intersection(&ocr_box, layout_box) else {
1218 continue;
1219 };
1220
1221 for (other_idx, other_region) in text_regions.iter_mut().enumerate() {
1223 if other_idx == ocr_idx {
1224 continue;
1225 }
1226 if other_region.bounding_box.iou(&crop_box) > 0.8 {
1227 other_region.text = None;
1228 }
1229 }
1230
1231 if let Ok(crop_img) = BBoxCrop::crop_bounding_box(page_image, &crop_box) {
1232 crops.push(crop_img);
1233 crop_boxes.push((crop_box, j == 0));
1234 }
1235 }
1236 multi_layout_crop_count += crop_boxes.len();
1237
1238 if crops.is_empty() {
1239 continue;
1240 }
1241
1242 let mut rec_texts: Vec<String> = Vec::with_capacity(crops.len());
1244 let mut rec_scores: Vec<f32> = Vec::with_capacity(crops.len());
1245
1246 for batch_start in (0..crops.len()).step_by(region_batch_size.max(1)) {
1247 let batch_end = (batch_start + region_batch_size).min(crops.len());
1248 let batch: Vec<_> = crops[batch_start..batch_end].to_vec();
1249 let rec_input = ImageTaskInput::new(batch);
1250 let rec_result = text_recognition_adapter.execute(rec_input, None)?;
1251 rec_texts.extend(rec_result.texts);
1252 rec_scores.extend(rec_result.scores);
1253 }
1254
1255 for ((crop_box, is_first), (text, score)) in crop_boxes
1256 .into_iter()
1257 .zip(rec_texts.into_iter().zip(rec_scores.into_iter()))
1258 {
1259 if text.is_empty() {
1260 continue;
1261 }
1262 if is_first {
1263 text_regions[ocr_idx].bounding_box = crop_box.clone();
1264 text_regions[ocr_idx].dt_poly = Some(crop_box.clone());
1265 text_regions[ocr_idx].rec_poly = Some(crop_box.clone());
1266 text_regions[ocr_idx].text = Some(Arc::from(text));
1267 text_regions[ocr_idx].confidence = Some(score);
1268 } else {
1269 appended_regions.push(crate::oarocr::TextRegion {
1270 bounding_box: crop_box.clone(),
1271 dt_poly: Some(crop_box.clone()),
1272 rec_poly: Some(crop_box),
1273 text: Some(Arc::from(text)),
1274 confidence: Some(score),
1275 orientation_angle: None,
1276 word_boxes: None,
1277 });
1278 }
1279 }
1280 }
1281
1282 if !appended_regions.is_empty() {
1283 text_regions.extend(appended_regions);
1284 }
1285
1286 let mut fallback_blocks = 0usize;
1289 for elem in layout_elements.iter() {
1290 if is_excluded_layout(elem.element_type) {
1291 continue;
1292 }
1293 if matches!(
1294 elem.element_type,
1295 LayoutElementType::Image | LayoutElementType::Chart
1296 ) {
1297 continue;
1298 }
1299
1300 let mut has_text = false;
1301 for region in text_regions.iter() {
1302 if !region.text.as_ref().map(|t| !t.is_empty()).unwrap_or(false) {
1303 continue;
1304 }
1305 let inter_x_min = region.bounding_box.x_min().max(elem.bbox.x_min());
1306 let inter_y_min = region.bounding_box.y_min().max(elem.bbox.y_min());
1307 let inter_x_max = region.bounding_box.x_max().min(elem.bbox.x_max());
1308 let inter_y_max = region.bounding_box.y_max().min(elem.bbox.y_max());
1309 if inter_x_max - inter_x_min > min_pixels && inter_y_max - inter_y_min > min_pixels
1310 {
1311 has_text = true;
1312 break;
1313 }
1314 }
1315
1316 if has_text {
1317 continue;
1318 }
1319 fallback_blocks += 1;
1320
1321 if let Ok(crop_img) = BBoxCrop::crop_bounding_box(page_image, &elem.bbox) {
1323 let rec_input = ImageTaskInput::new(vec![crop_img]);
1324 let rec_result = text_recognition_adapter.execute(rec_input, None)?;
1325 if let (Some(text), Some(score)) =
1326 (rec_result.texts.first(), rec_result.scores.first())
1327 && !text.is_empty()
1328 {
1329 let crop_box = elem.bbox.clone();
1330 text_regions.push(crate::oarocr::TextRegion {
1331 bounding_box: crop_box.clone(),
1332 dt_poly: Some(crop_box.clone()),
1333 rec_poly: Some(crop_box),
1334 text: Some(Arc::from(text.as_str())),
1335 confidence: Some(*score),
1336 orientation_angle: None,
1337 word_boxes: None,
1338 });
1339 }
1340 }
1341 }
1342
1343 tracing::info!(
1344 "overall OCR refine: multi-layout OCR boxes={}, crops={}, fallback layout blocks={}",
1345 multi_layout_ocr_count,
1346 multi_layout_crop_count,
1347 fallback_blocks
1348 );
1349
1350 let _ = region_blocks;
1353
1354 Ok(())
1355 }
1356
    /// Splits OCR text regions that span multiple table cells.
    ///
    /// Cell boxes from all `tables` are matched against every OCR region.
    /// A region that overlaps at least two cells is cut horizontally into
    /// per-cell sub-boxes (plus the gaps before, between, and after the
    /// cells), each sub-crop is re-recognized with
    /// `text_recognition_adapter`, and `text_regions` is rebuilt from the
    /// outcome. Regions overlapping fewer than two cells pass through
    /// unchanged.
    ///
    /// # Errors
    /// Propagates `OCRError` from the text recognition adapter.
    fn split_ocr_bboxes_by_table_cells(
        tables: &[TableResult],
        text_regions: &mut Vec<crate::oarocr::TextRegion>,
        page_image: &image::RgbImage,
        text_recognition_adapter: &TextRecognitionAdapter,
    ) -> Result<(), OCRError> {
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use oar_ocr_core::processors::BoundingBox;

        // Flatten all non-degenerate cell boxes into [x1, y1, x2, y2] form.
        let mut cell_boxes: Vec<[f32; 4]> = Vec::new();
        for table in tables {
            for cell in &table.cells {
                let x1 = cell.bbox.x_min();
                let y1 = cell.bbox.y_min();
                let x2 = cell.bbox.x_max();
                let y2 = cell.bbox.y_max();
                if x2 > x1 && y2 > y1 {
                    cell_boxes.push([x1, y1, x2, y2]);
                }
            }
        }

        if cell_boxes.is_empty() || text_regions.is_empty() {
            return Ok(());
        }

        // Intersection area divided by the CELL's area (intersection-over-area,
        // not IoU — despite the name of the threshold constant used below).
        fn overlap_ratio_box_over_cell(box1: &[f32; 4], box2: &[f32; 4]) -> f32 {
            let x_left = box1[0].max(box2[0]);
            let y_top = box1[1].max(box2[1]);
            let x_right = box1[2].min(box2[2]);
            let y_bottom = box1[3].min(box2[3]);

            if x_right <= x_left || y_bottom <= y_top {
                return 0.0;
            }

            let inter_area = (x_right - x_left) * (y_bottom - y_top);
            let cell_area = (box2[2] - box2[0]) * (box2[3] - box2[1]);
            if cell_area <= 0.0 {
                0.0
            } else {
                inter_area / cell_area
            }
        }

        // Indices of cells whose overlap ratio with `ocr_box` exceeds
        // `threshold`, sorted left-to-right by the cell's x1.
        fn get_overlapping_cells(
            ocr_box: &[f32; 4],
            cells: &[[f32; 4]],
            threshold: f32,
        ) -> Vec<usize> {
            let mut overlapping = Vec::new();
            for (idx, cell) in cells.iter().enumerate() {
                if overlap_ratio_box_over_cell(ocr_box, cell) > threshold {
                    overlapping.push(idx);
                }
            }
            overlapping.sort_by(|&i, &j| {
                cells[i][0]
                    .partial_cmp(&cells[j][0])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            overlapping
        }

        // Cuts `ocr_box` horizontally along the given (left-to-right sorted)
        // cells: one sub-box per cell clipped to the OCR box, plus the gap
        // before the first cell, the gaps between consecutive cells, and the
        // gap after the last cell. Exact duplicates are removed via the f32
        // bit patterns of the coordinates.
        fn split_box_by_cells(
            ocr_box: &[f32; 4],
            cell_indices: &[usize],
            cells: &[[f32; 4]],
        ) -> Vec<[f32; 4]> {
            if cell_indices.is_empty() {
                return vec![*ocr_box];
            }

            let mut split_boxes: Vec<[f32; 4]> = Vec::new();
            let cells_to_split: Vec<[f32; 4]> = cell_indices.iter().map(|&i| cells[i]).collect();

            // Leading gap: OCR box starts left of the first cell.
            if ocr_box[0] < cells_to_split[0][0] {
                split_boxes.push([ocr_box[0], ocr_box[1], cells_to_split[0][0], ocr_box[3]]);
            }

            for (i, current_cell) in cells_to_split.iter().enumerate() {
                // Portion of the OCR box covered by this cell.
                split_boxes.push([
                    ocr_box[0].max(current_cell[0]),
                    ocr_box[1],
                    ocr_box[2].min(current_cell[2]),
                    ocr_box[3],
                ]);

                // Gap between this cell and the next, if any.
                if i + 1 < cells_to_split.len() {
                    let next_cell = cells_to_split[i + 1];
                    if current_cell[2] < next_cell[0] {
                        split_boxes.push([current_cell[2], ocr_box[1], next_cell[0], ocr_box[3]]);
                    }
                }
            }

            // Trailing gap: OCR box extends past the last cell.
            let last_cell = cells_to_split[cells_to_split.len() - 1];
            if last_cell[2] < ocr_box[2] {
                split_boxes.push([last_cell[2], ocr_box[1], ocr_box[2], ocr_box[3]]);
            }

            // Deduplicate while preserving insertion order.
            let mut unique = Vec::new();
            let mut seen = std::collections::HashSet::new();
            for b in split_boxes {
                let key = (
                    b[0].to_bits(),
                    b[1].to_bits(),
                    b[2].to_bits(),
                    b[3].to_bits(),
                );
                if seen.insert(key) {
                    unique.push(b);
                }
            }
            unique
        }

        // Only regions touching at least two cells are split.
        let k_min_cells = 2usize;
        let overlap_threshold = CELL_OVERLAP_IOU_THRESHOLD;

        let mut new_regions: Vec<crate::oarocr::TextRegion> =
            Vec::with_capacity(text_regions.len());

        for region in text_regions.iter() {
            let ocr_box = [
                region.bounding_box.x_min(),
                region.bounding_box.y_min(),
                region.bounding_box.x_max(),
                region.bounding_box.y_max(),
            ];

            let overlapping_cells = get_overlapping_cells(&ocr_box, &cell_boxes, overlap_threshold);

            // Region fits in (at most) one cell — keep it as-is.
            if overlapping_cells.len() < k_min_cells {
                new_regions.push(region.clone());
                continue;
            }

            let split_boxes = split_box_by_cells(&ocr_box, &overlapping_cells, &cell_boxes);

            for box_coords in split_boxes {
                let img_w = page_image.width() as i32;
                let img_h = page_image.height() as i32;

                // Round outward, then clamp into the image so the crop is valid.
                let mut x1 = box_coords[0].floor() as i32;
                let mut y1 = box_coords[1].floor() as i32;
                let mut x2 = box_coords[2].ceil() as i32;
                let mut y2 = box_coords[3].ceil() as i32;

                x1 = x1.clamp(0, img_w.saturating_sub(1));
                y1 = y1.clamp(0, img_h.saturating_sub(1));
                x2 = x2.clamp(0, img_w);
                y2 = y2.clamp(0, img_h);

                // Skip degenerate (≤ 1 px wide/tall) sub-boxes.
                if x2 - x1 <= 1 || y2 - y1 <= 1 {
                    continue;
                }

                let crop_w = (x2 - x1) as u32;
                let crop_h = (y2 - y1) as u32;
                if crop_w <= 1 || crop_h <= 1 {
                    continue;
                }

                let x1u = x1 as u32;
                let y1u = y1 as u32;
                if x1u >= page_image.width() || y1u >= page_image.height() {
                    continue;
                }
                // Re-clip the extent so the crop never reads past the image edge.
                let crop_w = crop_w.min(page_image.width() - x1u);
                let crop_h = crop_h.min(page_image.height() - y1u);
                if crop_w <= 1 || crop_h <= 1 {
                    continue;
                }

                let crop =
                    image::imageops::crop_imm(page_image, x1u, y1u, crop_w, crop_h).to_image();

                // Re-recognize the sub-crop; empty text drops the sub-box.
                let rec_input = ImageTaskInput::new(vec![crop]);
                let rec_result = text_recognition_adapter.execute(rec_input, None)?;
                if let (Some(text), Some(score)) =
                    (rec_result.texts.first(), rec_result.scores.first())
                    && !text.is_empty()
                {
                    // Note: the region keeps the un-clamped float coordinates,
                    // not the integer crop rectangle.
                    let bbox = BoundingBox::from_coords(
                        box_coords[0],
                        box_coords[1],
                        box_coords[2],
                        box_coords[3],
                    );
                    new_regions.push(crate::oarocr::TextRegion {
                        bounding_box: bbox.clone(),
                        dt_poly: Some(bbox.clone()),
                        rec_poly: Some(bbox),
                        text: Some(Arc::from(text.as_str())),
                        confidence: Some(*score),
                        orientation_angle: None,
                        word_boxes: None,
                    });
                }
            }
        }

        *text_regions = new_regions;
        Ok(())
    }
1583
1584 fn detect_layout_and_regions(
1585 &self,
1586 page_image: &image::RgbImage,
1587 ) -> Result<
1588 (
1589 Vec<crate::domain::structure::LayoutElement>,
1590 Option<Vec<crate::domain::structure::RegionBlock>>,
1591 ),
1592 OCRError,
1593 > {
1594 use oar_ocr_core::core::traits::task::ImageTaskInput;
1595 use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, RegionBlock};
1596
1597 let input = ImageTaskInput::new(vec![page_image.clone()]);
1598 let layout_result = self
1599 .pipeline
1600 .layout_detection_adapter
1601 .execute(input, None)?;
1602
1603 let mut layout_elements: Vec<LayoutElement> = Vec::new();
1604 if let Some(elements) = layout_result.elements.first() {
1605 for element in elements {
1606 let element_type_enum = LayoutElementType::from_label(&element.element_type);
1607 layout_elements.push(
1608 LayoutElement::new(element.bbox.clone(), element_type_enum, element.score)
1609 .with_label(element.element_type.clone()),
1610 );
1611 }
1612 }
1613
1614 let mut detected_region_blocks: Option<Vec<RegionBlock>> = None;
1615 if let Some(ref region_adapter) = self.pipeline.region_detection_adapter {
1616 let region_input = ImageTaskInput::new(vec![page_image.clone()]);
1617 if let Ok(region_result) = region_adapter.execute(region_input, None)
1618 && let Some(region_elements) = region_result.elements.first()
1619 && !region_elements.is_empty()
1620 {
1621 let blocks: Vec<RegionBlock> = region_elements
1622 .iter()
1623 .map(|e| RegionBlock {
1624 bbox: e.bbox.clone(),
1625 confidence: e.score,
1626 order_index: None,
1627 element_indices: Vec::new(),
1628 })
1629 .collect();
1630 detected_region_blocks = Some(blocks);
1631 }
1632 }
1633
1634 if layout_elements.len() > 1 {
1635 let removed = crate::domain::structure::remove_overlapping_layout_elements(
1636 &mut layout_elements,
1637 LAYOUT_OVERLAP_IOU_THRESHOLD,
1638 );
1639 if removed > 0 {
1640 tracing::info!(
1641 "Removing {} overlapping layout elements (threshold={})",
1642 removed,
1643 LAYOUT_OVERLAP_IOU_THRESHOLD
1644 );
1645 }
1646 }
1647
1648 crate::domain::structure::apply_standardized_layout_label_fixes(&mut layout_elements);
1649
1650 Ok((layout_elements, detected_region_blocks))
1651 }
1652
1653 fn recognize_formulas(
1654 &self,
1655 page_image: &image::RgbImage,
1656 layout_elements: &[crate::domain::structure::LayoutElement],
1657 ) -> Result<Vec<crate::domain::structure::FormulaResult>, OCRError> {
1658 use oar_ocr_core::core::traits::task::ImageTaskInput;
1659 use oar_ocr_core::domain::structure::FormulaResult;
1660 use oar_ocr_core::utils::BBoxCrop;
1661
1662 let Some(ref formula_adapter) = self.pipeline.formula_recognition_adapter else {
1663 return Ok(Vec::new());
1664 };
1665
1666 let formula_elements: Vec<_> = layout_elements
1667 .iter()
1668 .filter(|e| e.element_type.is_formula())
1669 .collect();
1670
1671 if formula_elements.is_empty() {
1672 tracing::debug!(
1673 "Formula recognition skipped: no formula regions from layout detection"
1674 );
1675 return Ok(Vec::new());
1676 }
1677
1678 let mut crops = Vec::new();
1679 let mut bboxes = Vec::new();
1680
1681 for elem in &formula_elements {
1682 match BBoxCrop::crop_bounding_box(page_image, &elem.bbox) {
1683 Ok(crop) => {
1684 crops.push(crop);
1685 bboxes.push(elem.bbox.clone());
1686 }
1687 Err(err) => {
1688 tracing::warn!("Formula region crop failed: {}", err);
1689 }
1690 }
1691 }
1692
1693 if crops.is_empty() {
1694 tracing::debug!(
1695 "Formula recognition skipped: all formula crops failed for {} regions",
1696 formula_elements.len()
1697 );
1698 return Ok(Vec::new());
1699 }
1700
1701 let input = ImageTaskInput::new(crops);
1702 let formula_result = formula_adapter.execute(input, None)?;
1703
1704 let mut formulas = Vec::new();
1705 for ((bbox, formula), score) in bboxes
1706 .into_iter()
1707 .zip(formula_result.formulas.into_iter())
1708 .zip(formula_result.scores.into_iter())
1709 {
1710 let width = bbox.x_max() - bbox.x_min();
1711 let height = bbox.y_max() - bbox.y_min();
1712 if width <= 0.0 || height <= 0.0 {
1713 tracing::warn!(
1714 "Skipping formula with non-positive bbox dimensions: w={:.2}, h={:.2}",
1715 width,
1716 height
1717 );
1718 continue;
1719 }
1720
1721 formulas.push(FormulaResult {
1722 bbox,
1723 latex: formula,
1724 confidence: score.unwrap_or(0.0),
1725 });
1726 }
1727
1728 Ok(formulas)
1729 }
1730
1731 fn detect_seal_text(
1732 &self,
1733 page_image: &image::RgbImage,
1734 layout_elements: &mut Vec<crate::domain::structure::LayoutElement>,
1735 ) -> Result<(), OCRError> {
1736 use oar_ocr_core::core::traits::task::ImageTaskInput;
1737 use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType};
1738 use oar_ocr_core::processors::Point;
1739 use oar_ocr_core::utils::BBoxCrop;
1740
1741 let Some(ref seal_adapter) = self.pipeline.seal_text_detection_adapter else {
1742 return Ok(());
1743 };
1744
1745 let seal_regions: Vec<_> = layout_elements
1746 .iter()
1747 .filter(|e| e.element_type == LayoutElementType::Seal)
1748 .map(|e| e.bbox.clone())
1749 .collect();
1750
1751 if seal_regions.is_empty() {
1752 tracing::debug!("Seal detection skipped: no seal regions from layout detection");
1753 return Ok(());
1754 }
1755
1756 let mut seal_crops = Vec::new();
1757 let mut crop_offsets = Vec::new();
1758
1759 for region_bbox in &seal_regions {
1760 match BBoxCrop::crop_bounding_box(page_image, region_bbox) {
1761 Ok(crop) => {
1762 seal_crops.push(crop);
1763 crop_offsets.push((region_bbox.x_min(), region_bbox.y_min()));
1764 }
1765 Err(err) => {
1766 tracing::warn!("Seal region crop failed: {}", err);
1767 }
1768 }
1769 }
1770
1771 if seal_crops.is_empty() {
1772 return Ok(());
1773 }
1774
1775 let input = ImageTaskInput::new(seal_crops);
1776 let seal_result = seal_adapter.execute(input, None)?;
1777
1778 for ((dx, dy), detections) in crop_offsets.iter().zip(seal_result.detections.into_iter()) {
1779 for detection in detections {
1780 let translated_bbox = crate::processors::BoundingBox::new(
1781 detection
1782 .bbox
1783 .points
1784 .iter()
1785 .map(|p| Point::new(p.x + dx, p.y + dy))
1786 .collect(),
1787 );
1788
1789 layout_elements.push(
1790 LayoutElement::new(translated_bbox, LayoutElementType::Seal, detection.score)
1791 .with_label("seal".to_string()),
1792 );
1793 }
1794 }
1795
1796 Ok(())
1797 }
1798
1799 fn sort_layout_elements_enhanced(
1800 layout_elements: &mut Vec<crate::domain::structure::LayoutElement>,
1801 page_width: f32,
1802 page_height: f32,
1803 ) {
1804 use oar_ocr_core::processors::layout_sorting::sort_layout_enhanced;
1805
1806 if layout_elements.is_empty() {
1807 return;
1808 }
1809
1810 let sortable_elements: Vec<_> = layout_elements
1811 .iter()
1812 .map(|e| (e.bbox.clone(), e.element_type))
1813 .collect();
1814
1815 let sorted_indices = sort_layout_enhanced(&sortable_elements, page_width, page_height);
1816 if sorted_indices.len() != layout_elements.len() {
1817 return;
1818 }
1819
1820 let sorted_elements: Vec<_> = sorted_indices
1821 .into_iter()
1822 .map(|idx| layout_elements[idx].clone())
1823 .collect();
1824 *layout_elements = sorted_elements;
1825 }
1826
1827 fn assign_region_block_membership(
1828 region_blocks: &mut [crate::domain::structure::RegionBlock],
1829 layout_elements: &[crate::domain::structure::LayoutElement],
1830 ) {
1831 use std::cmp::Ordering;
1832
1833 if region_blocks.is_empty() {
1834 return;
1835 }
1836
1837 region_blocks.sort_by(|a, b| {
1838 a.bbox
1839 .y_min()
1840 .partial_cmp(&b.bbox.y_min())
1841 .unwrap_or(Ordering::Equal)
1842 .then_with(|| {
1843 a.bbox
1844 .x_min()
1845 .partial_cmp(&b.bbox.x_min())
1846 .unwrap_or(Ordering::Equal)
1847 })
1848 });
1849
1850 for (i, region) in region_blocks.iter_mut().enumerate() {
1851 region.order_index = Some((i + 1) as u32);
1852 region.element_indices.clear();
1853 }
1854
1855 if layout_elements.is_empty() {
1856 return;
1857 }
1858
1859 for (elem_idx, elem) in layout_elements.iter().enumerate() {
1860 let elem_area = elem.bbox.area();
1861 if elem_area <= 0.0 {
1862 continue;
1863 }
1864
1865 let mut best_region: Option<usize> = None;
1866 let mut best_ioa = 0.0f32;
1867
1868 for (region_idx, region) in region_blocks.iter().enumerate() {
1869 let intersection = elem.bbox.intersection_area(®ion.bbox);
1870 if intersection <= 0.0 {
1871 continue;
1872 }
1873 let ioa = intersection / elem_area;
1874 if ioa > best_ioa {
1875 best_ioa = ioa;
1876 best_region = Some(region_idx);
1877 }
1878 }
1879
1880 if let Some(region_idx) = best_region
1881 && best_ioa >= REGION_MEMBERSHIP_IOA_THRESHOLD
1882 {
1883 region_blocks[region_idx].element_indices.push(elem_idx);
1884 }
1885 }
1886 }
1887
    /// Runs full-page OCR (text detection + recognition) guided by layout.
    ///
    /// Returns an empty list when either the text detection or the text
    /// recognition adapter is not configured. Formula regions are masked
    /// white before detection so formula glyphs are not picked up as text.
    /// Detected boxes may be split along region/layout container boxes,
    /// each surviving crop is (optionally) orientation-corrected and
    /// recognized in batches, and the result is finally refined against the
    /// layout via `refine_overall_ocr_with_layout`.
    ///
    /// # Errors
    /// Propagates `OCRError` from text detection, cropping, and refinement.
    /// Per-batch recognition/orientation failures are swallowed (their
    /// batches simply yield no regions).
    fn run_overall_ocr(
        &self,
        page_image: &image::RgbImage,
        layout_elements: &[crate::domain::structure::LayoutElement],
        region_blocks: Option<&[crate::domain::structure::RegionBlock]>,
    ) -> Result<Vec<crate::oarocr::TextRegion>, OCRError> {
        use crate::oarocr::TextRegion;
        use oar_ocr_core::core::traits::task::ImageTaskInput;
        use std::sync::Arc;

        // Overall OCR is optional: both adapters must be present.
        let Some(ref text_detection_adapter) = self.pipeline.text_detection_adapter else {
            return Ok(Vec::new());
        };
        let Some(ref text_recognition_adapter) = self.pipeline.text_recognition_adapter else {
            return Ok(Vec::new());
        };

        let mut text_regions = Vec::new();

        // Mask out formula regions so the text detector ignores them;
        // formulas are handled by the dedicated formula recognition pass.
        let mut ocr_image = page_image.clone();
        let mask_bboxes: Vec<crate::processors::BoundingBox> = layout_elements
            .iter()
            .filter(|e| e.element_type.is_formula())
            .map(|e| e.bbox.clone())
            .collect();

        if !mask_bboxes.is_empty() {
            crate::utils::mask_regions(&mut ocr_image, &mask_bboxes, [255, 255, 255]);
        }

        let input = ImageTaskInput::new(vec![ocr_image.clone()]);
        let det_result = text_detection_adapter.execute(input, None)?;

        // Boxes for the first (and only) image in the batch.
        let mut detection_boxes = if let Some(detections) = det_result.detections.first() {
            detections
                .iter()
                .map(|d| d.bbox.clone())
                .collect::<Vec<_>>()
        } else {
            Vec::new()
        };

        let raw_detection_boxes = detection_boxes.clone();
        if tracing::enabled!(tracing::Level::DEBUG) && !raw_detection_boxes.is_empty() {
            let raw_rects: Vec<[f32; 4]> = raw_detection_boxes
                .iter()
                .map(|b| [b.x_min(), b.y_min(), b.x_max(), b.y_max()])
                .collect();
            tracing::debug!("overall OCR text det boxes (raw): {:?}", raw_rects);
        }

        // Split detection boxes that straddle multiple "container" boxes
        // (region blocks when available, otherwise text-like layout elements)
        // so each piece can be recognized in its own container.
        if !detection_boxes.is_empty() {
            let mut split_boxes = Vec::new();
            let mut split_count = 0usize;

            let container_boxes: Vec<crate::processors::BoundingBox> =
                if let Some(regions) = region_blocks {
                    regions.iter().map(|r| r.bbox.clone()).collect()
                } else {
                    layout_elements
                        .iter()
                        .filter(|e| {
                            matches!(
                                e.element_type,
                                crate::domain::structure::LayoutElementType::DocTitle
                                    | crate::domain::structure::LayoutElementType::ParagraphTitle
                                    | crate::domain::structure::LayoutElementType::Text
                                    | crate::domain::structure::LayoutElementType::Content
                                    | crate::domain::structure::LayoutElementType::Abstract
                                    | crate::domain::structure::LayoutElementType::Header
                                    | crate::domain::structure::LayoutElementType::Footer
                                    | crate::domain::structure::LayoutElementType::Footnote
                                    | crate::domain::structure::LayoutElementType::Number
                                    | crate::domain::structure::LayoutElementType::Reference
                                    | crate::domain::structure::LayoutElementType::ReferenceContent
                                    | crate::domain::structure::LayoutElementType::Algorithm
                                    | crate::domain::structure::LayoutElementType::AsideText
                                    | crate::domain::structure::LayoutElementType::List
                                    | crate::domain::structure::LayoutElementType::FigureTitle
                                    | crate::domain::structure::LayoutElementType::TableTitle
                                    | crate::domain::structure::LayoutElementType::ChartTitle
                                    | crate::domain::structure::LayoutElementType::FigureTableChartTitle
                            )
                        })
                        .map(|e| e.bbox.clone())
                        .collect()
                };

            if !container_boxes.is_empty() {
                for bbox in detection_boxes.into_iter() {
                    let mut intersections: Vec<crate::processors::BoundingBox> = Vec::new();
                    let self_area = bbox.area();
                    if self_area <= 0.0 {
                        split_boxes.push(bbox);
                        continue;
                    }

                    for container in &container_boxes {
                        let inter_x_min = bbox.x_min().max(container.x_min());
                        let inter_y_min = bbox.y_min().max(container.y_min());
                        let inter_x_max = bbox.x_max().min(container.x_max());
                        let inter_y_max = bbox.y_max().min(container.y_max());

                        // Require a strictly > 2 px overlap in both axes.
                        if inter_x_max - inter_x_min <= 2.0 || inter_y_max - inter_y_min <= 2.0 {
                            continue;
                        }

                        let inter_bbox = crate::processors::BoundingBox::from_coords(
                            inter_x_min,
                            inter_y_min,
                            inter_x_max,
                            inter_y_max,
                        );
                        let inter_area = inter_bbox.area();
                        if inter_area <= 0.0 {
                            continue;
                        }

                        // Keep only intersections covering a meaningful share
                        // of the detection box.
                        let ioa = inter_area / self_area;
                        if ioa >= TEXT_BOX_SPLIT_IOA_THRESHOLD {
                            intersections.push(inter_bbox);
                        }
                    }

                    // Split only when the box genuinely spans >= 2 containers;
                    // otherwise keep the original box untouched.
                    if intersections.len() >= 2 {
                        split_count += intersections.len();
                        split_boxes.extend(intersections);
                    } else {
                        split_boxes.push(bbox);
                    }
                }

                if split_count > 0 {
                    tracing::debug!(
                        "Cross-layout re-recognition: split {} text boxes into {} sub-boxes",
                        split_count,
                        split_boxes.len()
                    );
                }

                detection_boxes = split_boxes;
            }
        }

        if tracing::enabled!(tracing::Level::DEBUG) && !detection_boxes.is_empty() {
            let pre_rec_rects: Vec<[f32; 4]> = detection_boxes
                .iter()
                .map(|b| [b.x_min(), b.y_min(), b.x_max(), b.y_max()])
                .collect();
            tracing::debug!(
                "overall OCR boxes pre-recognition (after splitting): {:?}",
                pre_rec_rects
            );
        }

        if !detection_boxes.is_empty() {
            use crate::oarocr::processors::{EdgeProcessor, TextCroppingProcessor};

            // Crop from the ORIGINAL page image (not the formula-masked one).
            let processor = TextCroppingProcessor::new(true);
            let cropped =
                processor.process((Arc::new(page_image.clone()), detection_boxes.clone()))?;

            // Some crops may fail; remember which detection index each
            // surviving crop belongs to.
            let mut cropped_images: Vec<image::RgbImage> = Vec::new();
            let mut valid_indices: Vec<usize> = Vec::new();

            for (idx, crop_result) in cropped.into_iter().enumerate() {
                if let Some(img) = crop_result {
                    cropped_images.push((*img).clone());
                    valid_indices.push(idx);
                }
            }

            if !cropped_images.is_empty() {
                // Sort crops by width/height ratio so each recognition batch
                // holds similarly-shaped images (presumably to reduce padding
                // waste in the recognizer — behavior inherited, not asserted).
                let mut items: Vec<(usize, f32, image::RgbImage)> = cropped_images
                    .into_iter()
                    .zip(valid_indices)
                    .map(|(img, det_idx)| {
                        let wh_ratio = img.width() as f32 / img.height().max(1) as f32;
                        (det_idx, wh_ratio, img)
                    })
                    .collect();

                items.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                let batch_size = self.pipeline.region_batch_size.unwrap_or(8).max(1);

                while !items.is_empty() {
                    let take_n = batch_size.min(items.len());
                    let mut batch_items: Vec<(usize, f32, image::RgbImage)> =
                        items.drain(0..take_n).collect();

                    // Optional text-line orientation: flip crops classified as
                    // class_id == 1 (assumed to mean 180°-rotated — confirm
                    // against the orientation model's label map).
                    if let Some(ref tlo_adapter) = self.pipeline.text_line_orientation_adapter {
                        let tlo_imgs: Vec<_> =
                            batch_items.iter().map(|(_, _, img)| img.clone()).collect();
                        let tlo_input = ImageTaskInput::new(tlo_imgs);
                        if let Ok(tlo_result) = tlo_adapter.execute(tlo_input, None) {
                            for (i, classifications) in
                                tlo_result.classifications.iter().enumerate()
                            {
                                if i >= batch_items.len() {
                                    break;
                                }
                                if let Some(top_cls) = classifications.first()
                                    && top_cls.class_id == 1
                                {
                                    batch_items[i].2 =
                                        image::imageops::rotate180(&batch_items[i].2);
                                }
                            }
                        }
                    }

                    let mut det_indices: Vec<usize> = Vec::with_capacity(batch_items.len());
                    let mut rec_imgs: Vec<image::RgbImage> = Vec::with_capacity(batch_items.len());
                    for (det_idx, _ratio, img) in batch_items {
                        det_indices.push(det_idx);
                        rec_imgs.push(img);
                    }

                    // Recognition errors drop the whole batch silently.
                    let rec_input = ImageTaskInput::new(rec_imgs);
                    if let Ok(rec_result) = text_recognition_adapter.execute(rec_input, None) {
                        for ((det_idx, text), score) in det_indices
                            .into_iter()
                            .zip(rec_result.texts.into_iter())
                            .zip(rec_result.scores.into_iter())
                        {
                            if text.is_empty() {
                                continue;
                            }

                            let bbox = detection_boxes[det_idx].clone();
                            text_regions.push(TextRegion {
                                bounding_box: bbox.clone(),
                                dt_poly: Some(bbox.clone()),
                                rec_poly: Some(bbox),
                                text: Some(Arc::from(text)),
                                confidence: Some(score),
                                orientation_angle: None,
                                word_boxes: None,
                            });
                        }
                    }
                }
            }
        }

        // Post-process the regions against layout/region information.
        let batch_size = self.pipeline.region_batch_size.unwrap_or(8).max(1);
        Self::refine_overall_ocr_with_layout(
            &mut text_regions,
            layout_elements,
            region_blocks,
            page_image,
            text_recognition_adapter,
            batch_size,
        )?;

        Ok(text_regions)
    }
2151
2152 pub fn predict(&self, image_path: impl Into<PathBuf>) -> Result<StructureResult, OCRError> {
2162 let image_path = image_path.into();
2163
2164 let image = image::open(&image_path).map_err(|e| OCRError::InvalidInput {
2166 message: format!(
2167 "failed to load image from '{}': {}",
2168 image_path.display(),
2169 e
2170 ),
2171 })?;
2172
2173 let mut result = self.predict_image(image.to_rgb8())?;
2174 result.input_path = std::sync::Arc::from(image_path.to_string_lossy().as_ref());
2175 Ok(result)
2176 }
2177
    /// Runs the full document-structure pipeline on an in-memory image.
    ///
    /// Stages, in order: preprocessing (orientation + rectification), layout
    /// and region detection, formula recognition, seal-text detection,
    /// reading-order sorting, region membership assignment, overall OCR,
    /// table analysis, splitting OCR boxes along table cells, rotating all
    /// geometry back into the original image frame, and result stitching.
    /// `input_path` on the result is set to "memory"; `predict` overwrites it
    /// with the real path.
    ///
    /// # Errors
    /// Propagates `OCRError` from any pipeline stage.
    pub fn predict_image(&self, image: image::RgbImage) -> Result<StructureResult, OCRError> {
        use crate::oarocr::preprocess::DocumentPreprocessor;
        use std::sync::Arc;

        // Orientation correction and (optional) UVDoc rectification.
        let preprocessor = DocumentPreprocessor::new(
            self.pipeline.document_orientation_adapter.as_ref(),
            self.pipeline.rectification_adapter.as_ref(),
        );
        let preprocess = preprocessor.preprocess(Arc::new(image))?;
        let current_image = preprocess.image;
        let orientation_angle = preprocess.orientation_angle;
        let rectified_img = preprocess.rectified_img;
        let rotation = preprocess.rotation;

        let (mut layout_elements, mut detected_region_blocks) =
            self.detect_layout_and_regions(&current_image)?;

        let mut tables = Vec::new();
        let mut formulas = self.recognize_formulas(&current_image, &layout_elements)?;

        // Appends extra Seal elements onto layout_elements.
        self.detect_seal_text(&current_image, &mut layout_elements)?;

        if !layout_elements.is_empty() {
            // Sort using the rectified image's dimensions when available.
            let (width, height) = if let Some(img) = &rectified_img {
                (img.width() as f32, img.height() as f32)
            } else {
                (current_image.width() as f32, current_image.height() as f32)
            };
            Self::sort_layout_elements_enhanced(&mut layout_elements, width, height);
        }

        if let Some(ref mut regions) = detected_region_blocks {
            Self::assign_region_block_membership(regions, &layout_elements);
        }

        let mut text_regions = self.run_overall_ocr(
            &current_image,
            &layout_elements,
            detected_region_blocks.as_deref(),
        )?;

        // Table analysis: classification, structure recognition, and cell
        // detection, driven by whichever adapters are configured.
        {
            let analyzer = crate::oarocr::table_analyzer::TableAnalyzer::new(
                crate::oarocr::table_analyzer::TableAnalyzerConfig {
                    table_classification_adapter: self
                        .pipeline
                        .table_classification_adapter
                        .as_ref(),
                    table_orientation_adapter: self.pipeline.table_orientation_adapter.as_ref(),
                    table_structure_recognition_adapter: self
                        .pipeline
                        .table_structure_recognition_adapter
                        .as_ref(),
                    wired_table_structure_adapter: self
                        .pipeline
                        .wired_table_structure_adapter
                        .as_ref(),
                    wireless_table_structure_adapter: self
                        .pipeline
                        .wireless_table_structure_adapter
                        .as_ref(),
                    table_cell_detection_adapter: self
                        .pipeline
                        .table_cell_detection_adapter
                        .as_ref(),
                    wired_table_cell_adapter: self.pipeline.wired_table_cell_adapter.as_ref(),
                    wireless_table_cell_adapter: self.pipeline.wireless_table_cell_adapter.as_ref(),
                    use_e2e_wired_table_rec: self.pipeline.use_e2e_wired_table_rec,
                    use_e2e_wireless_table_rec: self.pipeline.use_e2e_wireless_table_rec,
                },
            );
            tables.extend(analyzer.analyze_tables(
                &current_image,
                &layout_elements,
                &formulas,
                &text_regions,
            )?);
        }

        // OCR boxes spanning several table cells are split and re-recognized.
        if !tables.is_empty()
            && !text_regions.is_empty()
            && let Some(ref text_rec_adapter) = self.pipeline.text_recognition_adapter
        {
            Self::split_ocr_bboxes_by_table_cells(
                &tables,
                &mut text_regions,
                &current_image,
                text_rec_adapter,
            )?;
        }

        // All geometry was computed in the rotated frame; map every box back
        // into the original (pre-rotation) image coordinates.
        if let Some(rot) = rotation {
            let rotated_width = rot.rotated_width;
            let rotated_height = rot.rotated_height;
            let angle = rot.angle;

            for element in &mut layout_elements {
                element.bbox =
                    element
                        .bbox
                        .rotate_back_to_original(angle, rotated_width, rotated_height);
            }

            for table in &mut tables {
                table.bbox =
                    table
                        .bbox
                        .rotate_back_to_original(angle, rotated_width, rotated_height);

                for cell in &mut table.cells {
                    cell.bbox =
                        cell.bbox
                            .rotate_back_to_original(angle, rotated_width, rotated_height);
                }
            }

            for formula in &mut formulas {
                formula.bbox =
                    formula
                        .bbox
                        .rotate_back_to_original(angle, rotated_width, rotated_height);
            }

            for region in &mut text_regions {
                region.dt_poly = region
                    .dt_poly
                    .take()
                    .map(|poly| poly.rotate_back_to_original(angle, rotated_width, rotated_height));
                region.rec_poly = region
                    .rec_poly
                    .take()
                    .map(|poly| poly.rotate_back_to_original(angle, rotated_width, rotated_height));
                region.bounding_box = region.bounding_box.rotate_back_to_original(
                    angle,
                    rotated_width,
                    rotated_height,
                );

                if let Some(ref word_boxes) = region.word_boxes {
                    let transformed_word_boxes: Vec<_> = word_boxes
                        .iter()
                        .map(|wb| wb.rotate_back_to_original(angle, rotated_width, rotated_height))
                        .collect();
                    region.word_boxes = Some(transformed_word_boxes);
                }
            }

            if let Some(ref mut regions) = detected_region_blocks {
                for region in regions.iter_mut() {
                    region.bbox =
                        region
                            .bbox
                            .rotate_back_to_original(angle, rotated_width, rotated_height);
                }
            }
        }

        // Prefer the rectified image for the result; fall back to a copy of
        // the working image.
        let final_image = rectified_img.unwrap_or_else(|| Arc::new((*current_image).clone()));
        let mut result = StructureResult {
            input_path: Arc::from("memory"),
            index: 0,
            layout_elements,
            tables,
            formulas,
            text_regions: if text_regions.is_empty() {
                None
            } else {
                Some(text_regions)
            },
            orientation_angle,
            region_blocks: detected_region_blocks,
            rectified_img: Some(final_image),
            page_continuation_flags: None,
        };

        // Merge/stitch fragmented results (e.g. split paragraphs) in place.
        use crate::oarocr::stitching::{ResultStitcher, StitchConfig};
        let stitch_cfg = StitchConfig::default();
        ResultStitcher::stitch_with_config(&mut result, &stitch_cfg);

        Ok(result)
    }
2393}
2394
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created builder records the layout model path and leaves
    /// every optional component (tables, formulas, OCR) unset.
    #[test]
    fn test_structure_builder_new() {
        let b = OARStructureBuilder::new("models/layout.onnx");

        assert_eq!(
            b.layout_detection_model,
            PathBuf::from("models/layout.onnx")
        );
        assert!(b.table_classification_model.is_none());
        assert!(b.formula_recognition_model.is_none());
    }

    /// Configuring the full table stack (classification, cell detection,
    /// structure recognition, dict path) stores each model path plus the
    /// associated type tags.
    #[test]
    fn test_structure_builder_with_table_components() {
        let b = OARStructureBuilder::new("models/layout.onnx")
            .with_table_classification("models/table_cls.onnx")
            .with_table_cell_detection("models/table_cell.onnx", "wired")
            .with_table_structure_recognition("models/table_struct.onnx", "wired")
            .table_structure_dict_path("models/table_structure_dict.txt");

        assert!(b.table_classification_model.is_some());
        assert!(b.table_cell_detection_model.is_some());
        assert!(b.table_structure_recognition_model.is_some());
        // Compare through as_deref to avoid allocating expected Strings.
        assert_eq!(b.table_cell_detection_type.as_deref(), Some("wired"));
        assert_eq!(b.table_structure_recognition_type.as_deref(), Some("wired"));
        assert_eq!(
            b.table_structure_dict_path,
            Some(PathBuf::from("models/table_structure_dict.txt"))
        );
    }

    /// Enabling formula recognition stores the model, tokenizer path, and
    /// recognizer type tag.
    #[test]
    fn test_structure_builder_with_formula() {
        let b = OARStructureBuilder::new("models/layout.onnx").with_formula_recognition(
            "models/formula.onnx",
            "models/tokenizer.json",
            "pp_formulanet",
        );

        assert!(b.formula_recognition_model.is_some());
        assert!(b.formula_tokenizer_path.is_some());
        assert_eq!(b.formula_recognition_type.as_deref(), Some("pp_formulanet"));
    }

    /// Attaching OCR populates the detection model, recognition model, and
    /// character dictionary fields.
    #[test]
    fn test_structure_builder_with_ocr() {
        let b = OARStructureBuilder::new("models/layout.onnx").with_ocr(
            "models/det.onnx",
            "models/rec.onnx",
            "models/dict.txt",
        );

        assert!(b.text_detection_model.is_some());
        assert!(b.text_recognition_model.is_some());
        assert!(b.character_dict_path.is_some());
    }

    /// Layout config and batch-size knobs are stored verbatim on the builder.
    #[test]
    fn test_structure_builder_with_configuration() {
        let cfg = LayoutDetectionConfig {
            score_threshold: 0.5,
            max_elements: 100,
            ..Default::default()
        };

        let b = OARStructureBuilder::new("models/layout.onnx")
            .layout_detection_config(cfg.clone())
            .image_batch_size(4)
            .region_batch_size(64);

        assert!(b.layout_detection_config.is_some());
        assert_eq!(b.image_batch_size, Some(4));
        assert_eq!(b.region_batch_size, Some(64));
    }
}