use crate::apply_ort_config;
use crate::core::OCRError;
use crate::core::traits::{
adapter::{AdapterInfo, ModelAdapter},
task::Task,
};
use crate::domain::tasks::{TableStructureRecognitionConfig, TableStructureRecognitionTask};
use crate::impl_adapter_builder;
use crate::models::recognition::{SLANetModel, SLANetModelBuilder};
use crate::processors::TableStructureDecode;
use std::path::Path;
/// Adapter that pairs an SLANet model with a table-structure decoder.
///
/// The adapter implements `ModelAdapter` for `TableStructureRecognitionTask`:
/// it forwards a batch of images through the model and decodes the model
/// output into structure tokens, bounding boxes and a score per image.
#[derive(Debug)]
pub struct TableStructureRecognitionAdapter {
/// Model whose `forward` output provides the structure logits, bbox
/// predictions and shape info consumed by the decoder.
model: SLANetModel,
/// Decoder that turns the model output into structure tokens, bboxes and scores.
decoder: TableStructureDecode,
/// Metadata returned (cloned) by `ModelAdapter::info`.
info: AdapterInfo,
/// Configuration used by `execute` when the caller passes no per-call override.
config: TableStructureRecognitionConfig,
}
impl TableStructureRecognitionAdapter {
    /// Input shape that `SLANetWiredAdapterBuilder` uses by default.
    ///
    /// NOTE(review): the axis order (height/width) is not visible here — confirm
    /// against `SLANetModelBuilder::input_size`.
    pub const DEFAULT_INPUT_SHAPE: (u32, u32) = (512, 512);

    /// Input shape that `SLANetWirelessAdapterBuilder` uses by default.
    pub const DEFAULT_WIRELESS_INPUT_SHAPE: (u32, u32) = (488, 488);

    /// Assembles an adapter from parts that have already been constructed.
    ///
    /// * `model` - model used for the forward pass.
    /// * `decoder` - decoder applied to the model output.
    /// * `info` - metadata reported through `ModelAdapter::info`.
    /// * `config` - configuration used when `execute` receives no override.
    pub fn new(
        model: SLANetModel,
        decoder: TableStructureDecode,
        info: AdapterInfo,
        config: TableStructureRecognitionConfig,
    ) -> Self {
        Self {
            config,
            info,
            decoder,
            model,
        }
    }
}
impl ModelAdapter for TableStructureRecognitionAdapter {
    type Task = TableStructureRecognitionTask;

    /// Returns a clone of the adapter's metadata.
    fn info(&self) -> AdapterInfo {
        self.info.clone()
    }

    /// Runs the model over a batch of table images and decodes one result per image.
    ///
    /// When `config` is `Some`, it replaces the adapter's stored configuration for
    /// this call only.
    ///
    /// # Errors
    ///
    /// * `OCRError::InvalidInput` when `input.images` is empty, or when the decoder
    ///   output has no structure-token or bbox entry for one of the images.
    /// * An adapter execution error (built by `OCRError::adapter_execution_error`)
    ///   when the model forward pass or the decode step fails.
    fn execute(
        &self,
        input: <Self::Task as Task>::Input,
        config: Option<&<Self::Task as Task>::Config>,
    ) -> Result<<Self::Task as Task>::Output, OCRError> {
        // A per-call override wins over the configuration stored at build time.
        let effective_config = match config {
            Some(overrides) => overrides,
            None => &self.config,
        };

        let num_images = input.images.len();
        if num_images == 0 {
            return Err(OCRError::InvalidInput {
                message: String::from("No images provided"),
            });
        }
        tracing::debug!("Processing {} table images", num_images);

        let forward_result = self.model.forward(input.images);
        let model_output = forward_result.map_err(|source| {
            OCRError::adapter_execution_error(
                "TableStructureRecognitionAdapter",
                format!("model forward (batch_size={})", num_images),
                source,
            )
        })?;

        let decode_result = self.decoder.decode(
            &model_output.structure_logits,
            &model_output.bbox_preds,
            &model_output.shape_info,
        );
        let decode_output = decode_result.map_err(|source| {
            OCRError::adapter_execution_error("TableStructureRecognitionAdapter", "decode", source)
        })?;

        let mut all_structures = Vec::with_capacity(num_images);
        let mut all_bboxes = Vec::with_capacity(num_images);
        let mut all_scores = Vec::with_capacity(num_images);

        for img_idx in 0..num_images {
            // Tokens and bboxes are mandatory for every image; a missing entry is an error.
            let structure_tokens = match decode_output.structure_tokens.get(img_idx) {
                Some(tokens) => tokens,
                None => {
                    return Err(OCRError::InvalidInput {
                        message: format!("No structure tokens decoded for image {}", img_idx),
                    });
                }
            };
            let image_bboxes = match decode_output.bboxes.get(img_idx) {
                Some(boxes) => boxes,
                None => {
                    return Err(OCRError::InvalidInput {
                        message: format!("No bboxes decoded for image {}", img_idx),
                    });
                }
            };
            // A missing score is tolerated and treated as 0.0.
            let structure_score = match decode_output.structure_scores.get(img_idx) {
                Some(&score) => score,
                None => 0.0,
            };

            // Low scores are only reported; the result is still returned to the caller.
            if effective_config.score_threshold > structure_score {
                tracing::warn!(
                    "Image {}: Structure score {:.3} below threshold {:.3}, keeping result",
                    img_idx,
                    structure_score,
                    effective_config.score_threshold
                );
            }

            // Keep at most `max_structure_length` tokens.
            let token_limit = effective_config.max_structure_length;
            let structure: Vec<String> = structure_tokens
                .iter()
                .take(token_limit)
                .cloned()
                .collect();
            let kept = structure.len();
            if structure_tokens.len() > kept {
                tracing::warn!(
                    "Image {}: Structure tokens {} exceed max {}, truncating output",
                    img_idx,
                    structure_tokens.len(),
                    effective_config.max_structure_length
                );
            }
            tracing::debug!("Image {}: Final structure tokens: {:?}", img_idx, structure);

            // NOTE(review): bboxes are capped by the number of kept tokens, not by the
            // number of kept cell tokens — confirm this is intended when truncation occurs.
            let mut bbox: Vec<Vec<f32>> = Vec::new();
            for raw_coords in image_bboxes.iter().take(kept) {
                let coords: Vec<f32> = raw_coords.to_vec();
                tracing::debug!("Image {}: BBox coords: {:?}", img_idx, coords);
                bbox.push(coords);
            }
            if bbox.len() < kept {
                tracing::debug!(
                    "Image {}: {} bounding boxes for {} structure tokens (TD tokens only have bboxes)",
                    img_idx,
                    bbox.len(),
                    kept
                );
            }
            tracing::debug!("Image {}: Final bbox output: {:?}", img_idx, bbox);

            all_structures.push(structure);
            all_bboxes.push(bbox);
            all_scores.push(structure_score);
        }

        Ok(crate::domain::tasks::TableStructureRecognitionOutput {
            structures: all_structures,
            bboxes: all_bboxes,
            structure_scores: all_scores,
        })
    }

    /// This adapter accepts more than one image per `execute` call.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Batch size suggested to callers.
    fn recommended_batch_size(&self) -> usize {
        8
    }
}
// Builder for `TableStructureRecognitionAdapter` instances registered under the
// "table_structure_recognition_wired" adapter type.
//
// Only `//` comments are used inside the invocation so that the tokens seen by
// `impl_adapter_builder!` are unaffected.
impl_adapter_builder! {
    builder_name: SLANetWiredAdapterBuilder,
    adapter_name: TableStructureRecognitionAdapter,
    config_type: TableStructureRecognitionConfig,
    adapter_type: "table_structure_recognition_wired",
    adapter_desc: "Recognizes table structure for wired tables as HTML tokens",
    task_type: TableStructureRecognition,
    fields: {
        // Default comes from the adapter's constant so the two cannot drift apart.
        input_shape: Option<(u32, u32)> = Some(TableStructureRecognitionAdapter::DEFAULT_INPUT_SHAPE),
        dict_path: Option<std::path::PathBuf> = None,
        model_name_override: Option<String> = None,
    },
    methods: {
        // Overrides the model input shape passed to `SLANetModelBuilder::input_size`.
        pub fn input_shape(mut self, input_shape: (u32, u32)) -> Self {
            self.input_shape = Some(input_shape);
            self
        }

        // Sets the dictionary file used to construct the decoder (required by `build`).
        pub fn dict_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
            self.dict_path = Some(path.into());
            self
        }

        // Replaces the model name reported in the adapter's `AdapterInfo`.
        pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
            self.model_name_override = Some(model_name.into());
            self
        }
    }
    build: |builder: SLANetWiredAdapterBuilder, model_path: &Path| -> Result<TableStructureRecognitionAdapter, OCRError> {
        // Split the builder configuration into the task part and the ORT part;
        // validation failures surface as `OCRError::ConfigError`.
        let (task_config, ort_config) = builder.config
            .into_validated_parts()
            .map_err(|err| OCRError::ConfigError {
                message: err.to_string(),
            })?;

        // Build the model, applying the input shape only when one is set.
        let mut model_builder = SLANetModelBuilder::new();
        if let Some(input_shape) = builder.input_shape {
            model_builder = model_builder.input_size(input_shape);
        }
        let model = apply_ort_config!(model_builder, ort_config).build(model_path)?;

        // The decoder cannot be created without a dictionary file.
        let dict_path = builder.dict_path.ok_or_else(|| OCRError::ConfigError {
            message: "Dictionary path is required. Use .dict_path() to specify the path to table_structure_dict_ch.txt".to_string(),
        })?;
        let decoder = TableStructureDecode::from_dict_path(&dict_path)?;

        // Start from the builder's base info and apply the optional name override.
        let mut info = SLANetWiredAdapterBuilder::base_adapter_info();
        if let Some(model_name) = builder.model_name_override {
            info.model_name = model_name;
        }

        Ok(TableStructureRecognitionAdapter::new(
            model,
            decoder,
            info,
            task_config,
        ))
    },
}
// Builder for `TableStructureRecognitionAdapter` instances registered under the
// "table_structure_recognition_wireless" adapter type.
//
// Only `//` comments are used inside the invocation so that the tokens seen by
// `impl_adapter_builder!` are unaffected.
impl_adapter_builder! {
    builder_name: SLANetWirelessAdapterBuilder,
    adapter_name: TableStructureRecognitionAdapter,
    config_type: TableStructureRecognitionConfig,
    adapter_type: "table_structure_recognition_wireless",
    adapter_desc: "Recognizes table structure for wireless tables as HTML tokens",
    task_type: TableStructureRecognition,
    fields: {
        // Default comes from the adapter's constant so the two cannot drift apart.
        input_shape: Option<(u32, u32)> = Some(TableStructureRecognitionAdapter::DEFAULT_WIRELESS_INPUT_SHAPE),
        dict_path: Option<std::path::PathBuf> = None,
        model_name_override: Option<String> = None,
    },
    methods: {
        // Overrides the model input shape passed to `SLANetModelBuilder::input_size`.
        pub fn input_shape(mut self, input_shape: (u32, u32)) -> Self {
            self.input_shape = Some(input_shape);
            self
        }

        // Sets the dictionary file used to construct the decoder (required by `build`).
        pub fn dict_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
            self.dict_path = Some(path.into());
            self
        }

        // Replaces the model name reported in the adapter's `AdapterInfo`.
        pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
            self.model_name_override = Some(model_name.into());
            self
        }
    }
    build: |builder: SLANetWirelessAdapterBuilder, model_path: &Path| -> Result<TableStructureRecognitionAdapter, OCRError> {
        // Split the builder configuration into the task part and the ORT part;
        // validation failures surface as `OCRError::ConfigError`.
        let (task_config, ort_config) = builder.config
            .into_validated_parts()
            .map_err(|err| OCRError::ConfigError {
                message: err.to_string(),
            })?;

        // Build the model, applying the input shape only when one is set.
        let mut model_builder = SLANetModelBuilder::new();
        if let Some(input_shape) = builder.input_shape {
            model_builder = model_builder.input_size(input_shape);
        }
        let model = apply_ort_config!(model_builder, ort_config).build(model_path)?;

        // The decoder cannot be created without a dictionary file.
        let dict_path = builder.dict_path.ok_or_else(|| OCRError::ConfigError {
            message: "Dictionary path is required. Use .dict_path() to specify the path to table_structure_dict_ch.txt".to_string(),
        })?;
        let decoder = TableStructureDecode::from_dict_path(&dict_path)?;

        // Start from the builder's base info and apply the optional name override.
        let mut info = SLANetWirelessAdapterBuilder::base_adapter_info();
        if let Some(model_name) = builder.model_name_override {
            info.model_name = model_name;
        }

        Ok(TableStructureRecognitionAdapter::new(
            model,
            decoder,
            info,
            task_config,
        ))
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::traits::adapter::AdapterBuilder;

    /// The wired builder reports the wired adapter type.
    #[test]
    fn test_wired_builder_creation() {
        let builder = SLANetWiredAdapterBuilder::new();
        assert_eq!(builder.adapter_type(), "table_structure_recognition_wired");
    }

    /// The wireless builder reports the wireless adapter type.
    #[test]
    fn test_wireless_builder_creation() {
        let builder = SLANetWirelessAdapterBuilder::new();
        assert_eq!(
            builder.adapter_type(),
            "table_structure_recognition_wireless"
        );
    }

    /// Every fluent setter stores the value it was given.
    #[test]
    fn test_builder_fluent_api() {
        let builder = SLANetWiredAdapterBuilder::new()
            .input_shape((640, 640))
            .dict_path("models/table_structure_dict_ch.txt")
            .model_name("custom_slanet");

        assert_eq!(builder.input_shape, Some((640, 640)));
        // Previously the dictionary path was set but never verified.
        assert_eq!(
            builder.dict_path.as_deref(),
            Some(std::path::Path::new("models/table_structure_dict_ch.txt"))
        );
        assert_eq!(builder.model_name_override.as_deref(), Some("custom_slanet"));
    }
}