oar_ocr_core/predictors/
table_classification.rs1use super::builder::PredictorBuilderState;
6use crate::TaskPredictorBuilder;
7use crate::core::OcrResult;
8use crate::core::traits::OrtConfigurable;
9use crate::core::traits::adapter::AdapterBuilder;
10use crate::core::traits::task::ImageTaskInput;
11use crate::domain::adapters::TableClassificationAdapterBuilder;
12use crate::domain::tasks::document_orientation::Classification;
13use crate::domain::tasks::table_classification::{
14 TableClassificationConfig, TableClassificationTask,
15};
16use crate::predictors::TaskPredictorCore;
17use image::RgbImage;
18use std::path::Path;
19
20#[derive(Debug, Clone)]
22pub struct TableClassificationResult {
23 pub classifications: Vec<Vec<Classification>>,
25}
26
27pub struct TableClassificationPredictor {
29 core: TaskPredictorCore<TableClassificationTask>,
30}
31
32impl TableClassificationPredictor {
33 pub fn builder() -> TableClassificationPredictorBuilder {
35 TableClassificationPredictorBuilder::new()
36 }
37
38 pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<TableClassificationResult> {
40 let input = ImageTaskInput::new(images);
41 let output = self.core.predict(input)?;
42 Ok(TableClassificationResult {
43 classifications: output.classifications,
44 })
45 }
46}
47
48#[derive(TaskPredictorBuilder)]
50#[builder(config = TableClassificationConfig)]
51pub struct TableClassificationPredictorBuilder {
52 state: PredictorBuilderState<TableClassificationConfig>,
53 input_shape: (u32, u32),
54}
55
56impl TableClassificationPredictorBuilder {
57 pub fn new() -> Self {
59 Self {
60 state: PredictorBuilderState::new(TableClassificationConfig {
61 score_threshold: 0.5,
62 topk: 2,
63 }),
64 input_shape: (224, 224),
65 }
66 }
67
68 pub fn score_threshold(mut self, threshold: f32) -> Self {
70 self.state.config_mut().score_threshold = threshold;
71 self
72 }
73
74 pub fn topk(mut self, k: usize) -> Self {
76 self.state.config_mut().topk = k;
77 self
78 }
79
80 pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
82 self.input_shape = shape;
83 self
84 }
85
86 pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<TableClassificationPredictor> {
88 let Self { state, input_shape } = self;
89 let (config, ort_config) = state.into_parts();
90 let mut adapter_builder = TableClassificationAdapterBuilder::new()
91 .with_config(config.clone())
92 .input_shape(input_shape);
93
94 if let Some(ort_cfg) = ort_config {
95 adapter_builder = adapter_builder.with_ort_config(ort_cfg);
96 }
97
98 let adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
99 let task = TableClassificationTask::new(config.clone());
100
101 Ok(TableClassificationPredictor {
102 core: TaskPredictorCore::new(adapter, task, config),
103 })
104 }
105}
106
107impl Default for TableClassificationPredictorBuilder {
108 fn default() -> Self {
109 Self::new()
110 }
111}