Skip to main content

oar_ocr_core/predictors/
table_classification.rs

1//! Table Classification Predictor
2//!
3//! This module provides a high-level API for table classification (wired vs wireless tables).
4
5use super::builder::PredictorBuilderState;
6use crate::TaskPredictorBuilder;
7use crate::core::OcrResult;
8use crate::core::traits::OrtConfigurable;
9use crate::core::traits::adapter::AdapterBuilder;
10use crate::core::traits::task::ImageTaskInput;
11use crate::domain::adapters::TableClassificationAdapterBuilder;
12use crate::domain::tasks::document_orientation::Classification;
13use crate::domain::tasks::table_classification::{
14    TableClassificationConfig, TableClassificationTask,
15};
16use crate::predictors::TaskPredictorCore;
17use image::RgbImage;
18use std::path::Path;
19
20/// Table classification prediction result
21#[derive(Debug, Clone)]
22pub struct TableClassificationResult {
23    /// Classification results for each input image
24    pub classifications: Vec<Vec<Classification>>,
25}
26
27/// Table classification predictor
28pub struct TableClassificationPredictor {
29    core: TaskPredictorCore<TableClassificationTask>,
30}
31
32impl TableClassificationPredictor {
33    /// Create a new builder for the table classification predictor
34    pub fn builder() -> TableClassificationPredictorBuilder {
35        TableClassificationPredictorBuilder::new()
36    }
37
38    /// Predict table classifications in the given images.
39    pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<TableClassificationResult> {
40        let input = ImageTaskInput::new(images);
41        let output = self.core.predict(input)?;
42        Ok(TableClassificationResult {
43            classifications: output.classifications,
44        })
45    }
46}
47
48/// Builder for table classification predictor
49#[derive(TaskPredictorBuilder)]
50#[builder(config = TableClassificationConfig)]
51pub struct TableClassificationPredictorBuilder {
52    state: PredictorBuilderState<TableClassificationConfig>,
53    input_shape: (u32, u32),
54}
55
56impl TableClassificationPredictorBuilder {
57    /// Create a new builder with default configuration
58    pub fn new() -> Self {
59        Self {
60            state: PredictorBuilderState::new(TableClassificationConfig {
61                score_threshold: 0.5,
62                topk: 2,
63            }),
64            input_shape: (224, 224),
65        }
66    }
67
68    /// Set the score threshold
69    pub fn score_threshold(mut self, threshold: f32) -> Self {
70        self.state.config_mut().score_threshold = threshold;
71        self
72    }
73
74    /// Set the top-k predictions to return
75    pub fn topk(mut self, k: usize) -> Self {
76        self.state.config_mut().topk = k;
77        self
78    }
79
80    /// Set the model input shape (height, width)
81    pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
82        self.input_shape = shape;
83        self
84    }
85
86    /// Build the table classification predictor
87    pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<TableClassificationPredictor> {
88        let Self { state, input_shape } = self;
89        let (config, ort_config) = state.into_parts();
90        let mut adapter_builder = TableClassificationAdapterBuilder::new()
91            .with_config(config.clone())
92            .input_shape(input_shape);
93
94        if let Some(ort_cfg) = ort_config {
95            adapter_builder = adapter_builder.with_ort_config(ort_cfg);
96        }
97
98        let adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
99        let task = TableClassificationTask::new(config.clone());
100
101        Ok(TableClassificationPredictor {
102            core: TaskPredictorCore::new(adapter, task, config),
103        })
104    }
105}
106
107impl Default for TableClassificationPredictorBuilder {
108    fn default() -> Self {
109        Self::new()
110    }
111}