use crate::apply_ort_config;
use crate::core::OCRError;
use crate::core::traits::{
adapter::{AdapterInfo, ModelAdapter},
task::Task,
};
use crate::domain::tasks::{
Classification, TableClassificationConfig, TableClassificationOutput, TableClassificationTask,
};
use crate::impl_adapter_builder;
use crate::models::classification::{PPLCNetModel, PPLCNetModelBuilder, PPLCNetPostprocessConfig};
use std::path::Path;
/// Model adapter that classifies table images as wired or wireless using a
/// PP-LCNet classification model.
///
/// Instances are normally produced by `TableClassificationAdapterBuilder`.
#[derive(Debug)]
pub struct TableClassificationAdapter {
    /// Underlying PP-LCNet model that performs inference.
    model: PPLCNetModel,
    /// Adapter metadata returned (cloned) by `ModelAdapter::info`.
    info: AdapterInfo,
    /// Default task configuration, used when `execute` is called without an override.
    config: TableClassificationConfig,
    /// Postprocessing settings (label list and `topk`) handed to the model's forward
    /// pass; `topk` is replaced per call by the effective task config in `execute`.
    postprocess_config: PPLCNetPostprocessConfig,
}
impl TableClassificationAdapter {
    /// Input shape the builder falls back to when `input_shape` is not set.
    pub const DEFAULT_INPUT_SHAPE: (u32, u32) = (224, 224);

    /// Assembles an adapter from an already-built model, its metadata, the default
    /// task configuration, and the postprocessing settings passed to the model.
    pub fn new(
        model: PPLCNetModel,
        info: AdapterInfo,
        config: TableClassificationConfig,
        postprocess_config: PPLCNetPostprocessConfig,
    ) -> Self {
        Self { model, info, config, postprocess_config }
    }

    /// Label names in class-id order: index 0 is `wired_table`, index 1 is
    /// `wireless_table`.
    pub fn labels() -> Vec<String> {
        ["wired_table", "wireless_table"]
            .into_iter()
            .map(String::from)
            .collect()
    }
}
impl ModelAdapter for TableClassificationAdapter {
    type Task = TableClassificationTask;

    /// Returns a copy of the adapter metadata supplied at construction time.
    fn info(&self) -> AdapterInfo {
        self.info.clone()
    }

    /// Classifies each input table image as `wired_table` or `wireless_table`.
    ///
    /// When `config` is `Some`, it replaces the adapter's stored task config for
    /// this call only. Its `topk` is forwarded to the model's postprocessing.
    ///
    /// # Errors
    ///
    /// Returns an adapter execution error, annotated with the effective `topk`, if
    /// the model's forward pass fails.
    fn execute(
        &self,
        input: <Self::Task as Task>::Input,
        config: Option<&<Self::Task as Task>::Config>,
    ) -> Result<<Self::Task as Task>::Output, OCRError> {
        let effective_config = config.unwrap_or(&self.config);

        // The stored postprocess config carries the build-time `topk`; override it
        // with the value that applies to this call.
        let mut postprocess_config = self.postprocess_config.clone();
        postprocess_config.topk = effective_config.topk;

        let model_output = self
            .model
            .forward(input.images, &postprocess_config)
            .map_err(|e| {
                OCRError::adapter_execution_error(
                    "TableClassificationAdapter",
                    format!(
                        "failed to classify table type (topk={})",
                        effective_config.topk
                    ),
                    e,
                )
            })?;

        // Prefer label names produced by the model. Otherwise resolve ids against the
        // same label table the postprocessor received, so the two can never drift
        // apart. Ids outside that table get an explicit `unknown_<id>` label instead
        // of being silently reported as `wireless_table`.
        let label_names = model_output.label_names.unwrap_or_else(|| {
            let labels = &postprocess_config.labels;
            model_output
                .class_ids
                .iter()
                .map(|ids| {
                    ids.iter()
                        .map(|&id| {
                            labels
                                .get(id)
                                .cloned()
                                .unwrap_or_else(|| format!("unknown_{id}"))
                        })
                        .collect()
                })
                .collect()
        });

        // Pair each image's class ids with their scores and label names. `zip`
        // stops at the shortest sequence, so mismatched lengths are truncated.
        let classifications = model_output
            .class_ids
            .into_iter()
            .zip(model_output.scores)
            .zip(label_names)
            .map(|((class_ids, scores), labels)| {
                class_ids
                    .into_iter()
                    .zip(scores)
                    .zip(labels)
                    .map(|((class_id, score), label)| Classification::new(class_id, label, score))
                    .collect()
            })
            .collect();

        Ok(TableClassificationOutput { classifications })
    }

    /// The adapter accepts batches of images in a single `execute` call.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Suggested number of images per batch for callers that chunk their inputs.
    fn recommended_batch_size(&self) -> usize {
        32
    }
}
impl_adapter_builder! {
    builder_name: TableClassificationAdapterBuilder,
    adapter_name: TableClassificationAdapter,
    config_type: TableClassificationConfig,
    adapter_type: "table_classification",
    adapter_desc: "Classifies table images as wired or wireless",
    task_type: TableClassification,
    fields: {
        // Model input shape; defaults to TableClassificationAdapter::DEFAULT_INPUT_SHAPE.
        input_shape: (u32, u32) = TableClassificationAdapter::DEFAULT_INPUT_SHAPE,
        // When set, replaces the model name reported in the adapter info.
        model_name_override: Option<String> = None,
    },
    methods: {
        // Sets the input shape used to build the preprocessing pipeline.
        pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
            self.input_shape = shape;
            self
        }
        // Overrides the model name reported via AdapterInfo.
        pub fn model_name(mut self, name: impl Into<String>) -> Self {
            self.model_name_override = Some(name.into());
            self
        }
    }
    build: |builder: TableClassificationAdapterBuilder, model_path: &Path| -> Result<TableClassificationAdapter, OCRError> {
        let (task_config, ort_config) = builder
            .config
            .into_validated_parts()
            .map_err(|err| OCRError::ConfigError {
                message: err.to_string(),
            })?;

        // A zero dimension can never describe a usable model input, so reject it
        // up front with a configuration error instead of failing later.
        let (dim_a, dim_b) = builder.input_shape;
        if dim_a == 0 || dim_b == 0 {
            return Err(OCRError::ConfigError {
                message: format!(
                    "table classification input_shape must be non-zero, got {:?}",
                    builder.input_shape
                ),
            });
        }

        let preprocess_config = super::preprocessing::pp_lcnet_preprocess(builder.input_shape);
        let model = apply_ort_config!(
            PPLCNetModelBuilder::new().preprocess_config(preprocess_config),
            ort_config
        )
        .build(model_path)?;

        // Seed `topk` from the validated task config so the stored postprocess
        // settings agree with it. `execute` still overrides it per call.
        let postprocess_config = PPLCNetPostprocessConfig {
            labels: TableClassificationAdapter::labels(),
            topk: task_config.topk,
        };

        let mut info = TableClassificationAdapterBuilder::base_adapter_info();
        if let Some(model_name) = builder.model_name_override {
            info.model_name = model_name;
        }

        Ok(TableClassificationAdapter::new(
            model,
            info,
            task_config,
            postprocess_config,
        ))
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::traits::adapter::AdapterBuilder;

    #[test]
    fn test_builder_creation() {
        let builder = TableClassificationAdapterBuilder::new();
        assert_eq!(builder.adapter_type(), "table_classification");
    }

    #[test]
    fn test_builder_fluent_api() {
        let builder = TableClassificationAdapterBuilder::new().input_shape((256, 256));
        assert_eq!(builder.input_shape, (256, 256));
    }

    #[test]
    fn test_default_builder() {
        let builder = TableClassificationAdapterBuilder::default();
        assert_eq!(builder.adapter_type(), "table_classification");
    }

    // A freshly created builder uses the adapter's default shape and no name override.
    #[test]
    fn test_builder_field_defaults() {
        let builder = TableClassificationAdapterBuilder::new();
        assert_eq!(
            builder.input_shape,
            TableClassificationAdapter::DEFAULT_INPUT_SHAPE
        );
        assert_eq!(builder.model_name_override, None);
    }

    // The `model_name` setter stores the override verbatim.
    #[test]
    fn test_model_name_override() {
        let builder = TableClassificationAdapterBuilder::new().model_name("custom_table_cls");
        assert_eq!(
            builder.model_name_override.as_deref(),
            Some("custom_table_cls")
        );
    }

    #[test]
    fn test_labels() {
        let labels = TableClassificationAdapter::labels();
        assert_eq!(labels, vec!["wired_table", "wireless_table"]);
    }
}