use crate::core::OCRError;
use crate::core::inference::OrtInfer;
use crate::core::traits::adapter::{AdapterBuilder, AdapterInfo, ModelAdapter};
use crate::core::traits::task::{Task, TaskType};
use crate::domain::tasks::{
TableCellDetection, TableCellDetectionConfig, TableCellDetectionOutput, TableCellDetectionTask,
};
use crate::impl_adapter_builder;
use crate::models::detection::{RTDetrModel, RTDetrModelBuilder, RTDetrPostprocessConfig};
use crate::processors::{ImageScaleInfo, LayoutPostProcess};
use std::collections::HashMap;
use std::path::Path;
/// Static description of a table cell detection model.
///
/// The adapter builder in this file reads these values to pick the model
/// backend, size the model input, and name detected classes.
#[derive(Debug, Clone)]
pub struct TableCellModelConfig {
    /// Model identifier; embedded in the adapter's name and description.
    pub model_name: String,
    /// Number of classes the model predicts.
    pub num_classes: usize,
    /// Human-readable label for each class id.
    pub class_labels: HashMap<usize, String>,
    /// Backend selector; the builder in this file only accepts `"rtdetr"`.
    pub model_type: String,
    /// Optional model input size, destructured by the builder as `(height, width)`.
    pub input_size: Option<(u32, u32)>,
}

impl TableCellModelConfig {
    /// Preset for the `rt-detr-l_wired_table_cell_det` model.
    pub fn rtdetr_l_wired_table_cell_det() -> Self {
        Self::single_class_rtdetr("rt-detr-l_wired_table_cell_det")
    }

    /// Preset for the `rt-detr-l_wireless_table_cell_det` model.
    pub fn rtdetr_l_wireless_table_cell_det() -> Self {
        Self::single_class_rtdetr("rt-detr-l_wireless_table_cell_det")
    }

    /// Shared body of both presets: one class (`0 -> "cell"`), the `"rtdetr"`
    /// backend and a 640x640 input. Only the model name differs.
    fn single_class_rtdetr(model_name: &str) -> Self {
        Self {
            model_name: model_name.to_string(),
            num_classes: 1,
            class_labels: HashMap::from([(0, "cell".to_string())]),
            model_type: "rtdetr".to_string(),
            input_size: Some((640, 640)),
        }
    }
}
/// Model backend owned by the adapter.
///
/// `execute` matches on this enum to run inference; only one backend
/// variant is defined here.
#[derive(Debug)]
enum TableCellModel {
/// Backend wrapping an [`RTDetrModel`].
RTDetr(RTDetrModel),
}
/// Adapter that runs a table cell detection model and turns its raw
/// predictions into per-image cell detections.
#[derive(Debug)]
pub struct TableCellDetectionAdapter {
/// Backend used for the forward pass.
model: TableCellModel,
/// Post-processor applied to the raw predictions before filtering.
postprocessor: LayoutPostProcess,
/// Model description; supplies `num_classes` and the class-id -> label map.
model_config: TableCellModelConfig,
/// Metadata returned by `info()`.
info: AdapterInfo,
/// Task configuration used when `execute` is called without an override.
config: TableCellDetectionConfig,
}
impl TableCellDetectionAdapter {
    /// Assembles an adapter around an RT-DETR model.
    ///
    /// The model is wrapped in [`TableCellModel::RTDetr`]; every other
    /// argument is stored unchanged.
    fn new_rtdetr(
        model: RTDetrModel,
        postprocessor: LayoutPostProcess,
        model_config: TableCellModelConfig,
        info: AdapterInfo,
        config: TableCellDetectionConfig,
    ) -> Self {
        Self {
            model: TableCellModel::RTDetr(model),
            postprocessor,
            model_config,
            info,
            config,
        }
    }

    /// Converts raw model predictions into per-image cell detections.
    ///
    /// `predictions` and `img_shapes` are passed to the stored post-processor,
    /// which yields parallel per-image collections of boxes, class ids and
    /// scores. For every image:
    ///
    /// * detections scoring below `config.score_threshold` are discarded;
    /// * at most `config.max_cells` of the remaining detections are kept, in
    ///   the order the post-processor produced them (`max_cells == 0` yields
    ///   an empty list);
    /// * each kept detection is labelled from `model_config.class_labels`,
    ///   falling back to `"cell"` for an unknown class id.
    ///
    /// The output holds one `Vec` of cells per image.
    fn postprocess(
        &self,
        predictions: &ndarray::Array4<f32>,
        img_shapes: Vec<ImageScaleInfo>,
        config: &TableCellDetectionConfig,
    ) -> TableCellDetectionOutput {
        let (boxes, class_ids, scores) = self.postprocessor.apply(predictions, img_shapes);

        let cells = boxes
            .into_iter()
            .zip(class_ids.into_iter().zip(scores))
            .map(|(img_boxes, (img_classes, img_scores))| {
                img_boxes
                    .into_iter()
                    .zip(img_classes.into_iter().zip(img_scores))
                    .filter(|(_, (_, score))| {
                        // Reject only scores that compare strictly below the
                        // threshold, so a NaN score is handled as it always
                        // was (it is not rejected).
                        let below_threshold = *score < config.score_threshold;
                        !below_threshold
                    })
                    // Applying the cap after filtering (instead of checking it
                    // after each push) makes a limit of zero produce no cells.
                    .take(config.max_cells)
                    .map(|(bbox, (class_id, score))| {
                        let label = self
                            .model_config
                            .class_labels
                            .get(&class_id)
                            .cloned()
                            .unwrap_or_else(|| "cell".to_string());
                        TableCellDetection { bbox, score, label }
                    })
                    .collect()
            })
            .collect();

        TableCellDetectionOutput { cells }
    }
}
impl ModelAdapter for TableCellDetectionAdapter {
    type Task = TableCellDetectionTask;

    /// Returns a copy of the adapter metadata.
    fn info(&self) -> AdapterInfo {
        self.info.clone()
    }

    /// Runs detection on `input.images` and returns the detected cells.
    ///
    /// `config` overrides the adapter's stored task configuration for this
    /// call; `None` falls back to the stored one.
    ///
    /// NOTE(review): the post-processor was constructed from the build-time
    /// configuration and receives no per-call settings, so an override
    /// presumably only affects the final filtering step — confirm against
    /// `LayoutPostProcess`.
    ///
    /// # Errors
    ///
    /// A failing forward pass is reported through
    /// `OCRError::adapter_execution_error`, with the batch size included in
    /// the context string.
    fn execute(
        &self,
        input: <Self::Task as Task>::Input,
        config: Option<&<Self::Task as Task>::Config>,
    ) -> Result<<Self::Task as Task>::Output, OCRError> {
        let active_config = match config {
            Some(overrides) => overrides,
            None => &self.config,
        };
        // Read before `input.images` is handed over to `forward`.
        let batch_size = input.images.len();

        let (predictions, scale_infos) = match &self.model {
            TableCellModel::RTDetr(rtdetr) => {
                let rtdetr_config = RTDetrPostprocessConfig {
                    num_classes: self.model_config.num_classes,
                };
                let forward_result = rtdetr.forward(input.images, &rtdetr_config);
                let (raw_output, scale_infos) = forward_result.map_err(|e| {
                    OCRError::adapter_execution_error(
                        "TableCellDetectionAdapter",
                        format!("RTDetr forward (batch_size={})", batch_size),
                        e,
                    )
                })?;
                (raw_output.predictions, scale_infos)
            }
        };

        let output = self.postprocess(&predictions, scale_infos, active_config);
        Ok(output)
    }

    /// This adapter accepts multi-image batches.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Suggested number of images per batch.
    fn recommended_batch_size(&self) -> usize {
        4
    }
}
// Declares `TableCellDetectionAdapterBuilder` through the crate's
// `impl_adapter_builder!` macro. The macro definition is not in this file;
// code further down relies on the generated builder exposing `new`,
// `with_config`, `with_ort_config` and `build`.
// Only `//` comments are used inside the invocation so that no extra tokens
// reach the macro.
impl_adapter_builder! {
builder_name: TableCellDetectionAdapterBuilder,
adapter_name: TableCellDetectionAdapter,
config_type: TableCellDetectionConfig,
adapter_type: "table_cell_detection",
adapter_desc: "Detects table cell boundaries in table images",
task_type: TableCellDetection,
// Extra builder state: the model description, unset until `model_config` is called.
fields: {
model_config: Option<TableCellModelConfig> = None,
},
methods: {
// Selects the model description to build with (required before `build`).
pub fn model_config(mut self, config: TableCellModelConfig) -> Self {
self.model_config = Some(config);
self
}
// Sets the score threshold stored in the task configuration.
pub fn score_threshold(mut self, threshold: f32) -> Self {
self.config.task_config.score_threshold = threshold;
self
}
// Sets the per-image cell limit stored in the task configuration.
pub fn max_cells(mut self, max: usize) -> Self {
self.config.task_config.max_cells = max;
self
}
}
// Build step: fails with `InvalidInput` when no model configuration was set,
// maps a configuration validation failure to `ConfigError`, then delegates
// to `build_with_config`.
build: |builder: TableCellDetectionAdapterBuilder, model_path: &Path| -> Result<TableCellDetectionAdapter, OCRError> {
let model_config = builder.model_config.ok_or_else(|| OCRError::InvalidInput {
message: "Table cell model configuration is required".to_string(),
})?;
let (task_config, ort_config) = builder.config
.into_validated_parts()
.map_err(|err| OCRError::ConfigError {
message: err.to_string(),
})?;
TableCellDetectionAdapterBuilder::build_with_config(model_path, model_config, task_config, ort_config)
},
}
impl TableCellDetectionAdapterBuilder {
    /// Creates the inference session, post-processor, metadata and model, and
    /// combines them into a [`TableCellDetectionAdapter`].
    ///
    /// When `ort_config` is present the session is created with
    /// `OrtInfer::from_config`, using an otherwise default
    /// `ModelInferenceConfig` that carries it; without one, `OrtInfer::new`
    /// is used.
    ///
    /// # Errors
    ///
    /// Propagates failures from creating the inference session or building
    /// the model, and returns `OCRError::InvalidInput` when
    /// `model_config.model_type` is anything other than `"rtdetr"`.
    fn build_with_config(
        model_path: &Path,
        model_config: TableCellModelConfig,
        task_config: TableCellDetectionConfig,
        ort_config: Option<crate::core::config::OrtSessionConfig>,
    ) -> Result<TableCellDetectionAdapter, OCRError> {
        let inference = match ort_config {
            Some(session_config) => {
                use crate::core::config::ModelInferenceConfig;
                let inference_config = ModelInferenceConfig {
                    ort_session: Some(session_config),
                    ..Default::default()
                };
                OrtInfer::from_config(&inference_config, model_path, None)?
            }
            None => OrtInfer::new(model_path, None)?,
        };

        // NOTE(review): the third argument is a fixed 0.5 — presumably an
        // overlap/NMS threshold; confirm against `LayoutPostProcess::new`.
        let postprocessor = LayoutPostProcess::new(
            model_config.num_classes,
            task_config.score_threshold,
            0.5,
            task_config.max_cells,
            model_config.model_type.clone(),
        );

        let adapter_name = format!("TableCellDetection_{}", model_config.model_name);
        let adapter_description = format!(
            "Table cell detection adapter for {} with {} classes",
            model_config.model_name, model_config.num_classes
        );
        let info = AdapterInfo::new(
            adapter_name,
            TaskType::TableCellDetection,
            adapter_description,
        );

        // Guard clause: "rtdetr" is the only backend this builder can create.
        if model_config.model_type.as_str() != "rtdetr" {
            return Err(OCRError::InvalidInput {
                message: format!(
                    "Unsupported model type '{}' for table cell detection. Supported type: rtdetr",
                    model_config.model_type
                ),
            });
        }

        // An explicit input size, when configured, is given as (height, width).
        let model_builder = match model_config.input_size {
            Some((height, width)) => RTDetrModelBuilder::new().image_shape(height, width),
            None => RTDetrModelBuilder::new(),
        };
        let model = model_builder.build(inference)?;

        Ok(TableCellDetectionAdapter::new_rtdetr(
            model,
            postprocessor,
            model_config,
            info,
            task_config,
        ))
    }
}
/// Convenience builder that wraps [`TableCellDetectionAdapterBuilder`] with
/// an RT-DETR model configuration already selected.
#[derive(Debug)]
pub struct RTDetrTableCellAdapterBuilder {
/// Underlying builder that every call is forwarded to.
inner: TableCellDetectionAdapterBuilder,
}
impl Default for RTDetrTableCellAdapterBuilder {
    /// Builder preconfigured with the wired-table RT-DETR preset.
    fn default() -> Self {
        let wired_preset = TableCellModelConfig::rtdetr_l_wired_table_cell_det();
        let inner = TableCellDetectionAdapterBuilder::new().model_config(wired_preset);
        Self { inner }
    }
}
impl RTDetrTableCellAdapterBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn wireless() -> Self {
Self {
inner: TableCellDetectionAdapterBuilder::new()
.model_config(TableCellModelConfig::rtdetr_l_wireless_table_cell_det()),
}
}
pub fn score_threshold(mut self, threshold: f32) -> Self {
self.inner = self.inner.score_threshold(threshold);
self
}
pub fn max_cells(mut self, max: usize) -> Self {
self.inner = self.inner.max_cells(max);
self
}
}
impl crate::core::traits::OrtConfigurable for RTDetrTableCellAdapterBuilder {
    /// Forwards the ORT session configuration to the wrapped builder.
    fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self {
        Self {
            inner: self.inner.with_ort_config(config),
        }
    }
}
impl AdapterBuilder for RTDetrTableCellAdapterBuilder {
    type Config = TableCellDetectionConfig;
    type Adapter = TableCellDetectionAdapter;

    /// Builds the adapter by delegating to the wrapped builder.
    fn build(self, model_path: &Path) -> Result<Self::Adapter, OCRError> {
        let Self { inner } = self;
        inner.build(model_path)
    }

    /// Forwards the task configuration to the wrapped builder.
    fn with_config(self, config: Self::Config) -> Self {
        Self {
            inner: self.inner.with_config(config),
        }
    }

    /// Identifier of this builder.
    fn adapter_type(&self) -> &str {
        "RTDetrTableCell"
    }
}