use crate::apply_ort_config;
use crate::core::OCRError;
use crate::core::traits::{
adapter::{AdapterInfo, ModelAdapter},
task::Task,
};
use crate::domain::tasks::{
Detection, SealTextDetectionConfig, SealTextDetectionOutput, SealTextDetectionTask,
};
use crate::impl_adapter_builder;
use crate::models::detection::db::{DBModel, DBModelBuilder, DBPostprocessConfig};
use crate::processors::{BoxType, ScoreMode};
/// Model adapter that runs a DB (`DBModel`) text detector set up for seal text.
///
/// Bundles the loaded model, the metadata describing this adapter, and the default task
/// configuration used by [`ModelAdapter::execute`] whenever the caller does not pass a
/// per-call configuration.
#[derive(Debug)]
pub struct SealTextDetectionAdapter {
    /// The underlying detection model invoked on every `execute` call.
    model: DBModel,
    /// Metadata handed out (as a clone) by [`ModelAdapter::info`].
    info: AdapterInfo,
    /// Fallback configuration for `execute` calls that receive `None`.
    config: SealTextDetectionConfig,
}

impl SealTextDetectionAdapter {
    /// Creates an adapter from an already-built model, its metadata, and a default config.
    ///
    /// The arguments are stored exactly as given; no validation happens here.
    pub fn new(model: DBModel, info: AdapterInfo, config: SealTextDetectionConfig) -> Self {
        Self { model, info, config }
    }
}
impl ModelAdapter for SealTextDetectionAdapter {
type Task = SealTextDetectionTask;
fn info(&self) -> AdapterInfo {
self.info.clone()
}
fn execute(
&self,
input: <Self::Task as Task>::Input,
config: Option<&<Self::Task as Task>::Config>,
) -> Result<<Self::Task as Task>::Output, OCRError> {
let effective_config = config.unwrap_or(&self.config);
let model_output = self.model
.forward(
input.images,
effective_config.score_threshold,
effective_config.box_threshold,
effective_config.unclip_ratio,
)
.map_err(|e| {
OCRError::adapter_execution_error(
"SealTextDetectionAdapter",
format!(
"failed to detect seal text (score_threshold={}, box_threshold={}, unclip_ratio={})",
effective_config.score_threshold,
effective_config.box_threshold,
effective_config.unclip_ratio
),
e,
)
})?;
let detections = model_output
.boxes
.into_iter()
.zip(model_output.scores)
.map(|(boxes, scores)| {
boxes
.into_iter()
.zip(scores)
.map(|(bbox, score)| Detection::new(bbox, score))
.collect()
})
.collect();
Ok(SealTextDetectionOutput { detections })
}
fn supports_batching(&self) -> bool {
true
}
fn recommended_batch_size(&self) -> usize {
8
}
}
// Generates `SealTextDetectionAdapterBuilder`, the builder for `SealTextDetectionAdapter`.
// The `build` closure:
//   1. validates the builder's config and splits it into task and ORT parts;
//   2. sets up seal-specific DB pre- and post-processing;
//   3. builds the model from `model_path`.
// Only `//` comments are used inside the macro input; `///` would become attributes and
// could change how the macro parses its arguments.
impl_adapter_builder! {
    builder_name: SealTextDetectionAdapterBuilder,
    adapter_name: SealTextDetectionAdapter,
    config_type: SealTextDetectionConfig,
    adapter_type: "seal_text_detection",
    adapter_desc: "Detects curved seal text with polygon bounding boxes",
    task_type: SealTextDetection,
    fields: {},
    methods: {}
    build: |builder: SealTextDetectionAdapterBuilder, model_path: &std::path::Path| -> Result<SealTextDetectionAdapter, OCRError> {
        // Validate the builder config and split it into the task-level settings and the
        // ORT (presumably ONNX Runtime session) settings. Validation errors are reported as
        // `OCRError::ConfigError` carrying the original error's message.
        let (task_config, ort_config) = builder.config
            .into_validated_parts()
            .map_err(|err| OCRError::ConfigError {
                message: err.to_string(),
            })?;
        // Preprocessing preset chosen for the "seal" text type.
        let preprocess_config = super::preprocessing::db_preprocess_for_text_type(Some("seal"));
        // Post-processing is fixed here at build time:
        //   - thresholds and `max_candidates` come from the validated task config;
        //   - boxes are polygons (`BoxType::Poly`), matching the adapter description;
        //   - dilation is disabled and `ScoreMode::Fast` scoring is used.
        let postprocess_config = DBPostprocessConfig {
            score_threshold: task_config.score_threshold,
            box_threshold: task_config.box_threshold,
            unclip_ratio: task_config.unclip_ratio,
            max_candidates: task_config.max_candidates,
            use_dilation: false,
            score_mode: ScoreMode::Fast,
            box_type: BoxType::Poly,
        };
        // Apply the ORT settings to the model builder, then build the model from `model_path`.
        let model = apply_ort_config!(
            DBModelBuilder::new()
                .preprocess_config(preprocess_config)
                .postprocess_config(postprocess_config),
            ort_config
        )
        .build(model_path)?;
        let info = SealTextDetectionAdapterBuilder::base_adapter_info();
        // The validated task config is stored as the adapter's default, used by `execute`
        // when it is called without a per-call config.
        Ok(SealTextDetectionAdapter::new(
            model,
            info,
            task_config,
        ))
    },
}