use crate::apply_ort_config;
use crate::core::OCRError;
use crate::core::traits::{
adapter::{AdapterInfo, ModelAdapter},
task::Task,
};
use crate::domain::tasks::{
Detection, TextDetectionConfig, TextDetectionOutput, TextDetectionTask,
};
use crate::impl_adapter_builder;
use crate::models::detection::db::{DBModel, DBModelBuilder, DBPostprocessConfig};
use crate::processors::{BoxType, ScoreMode};
/// Adapter that exposes a [`DBModel`] as a [`TextDetectionTask`] via [`ModelAdapter`].
///
/// Instances are assembled by `TextDetectionAdapterBuilder`, which is generated by the
/// `impl_adapter_builder!` invocation in this file.
#[derive(Debug)]
pub struct TextDetectionAdapter {
/// Underlying detection model. Its pre/post-processing configuration is fixed when the
/// adapter is built; only the thresholds are re-supplied on each `execute` call.
model: DBModel,
/// Adapter metadata; a clone is handed out by `ModelAdapter::info`.
info: AdapterInfo,
/// Validated task configuration, used by `execute` when no per-call config is given.
config: TextDetectionConfig,
}
impl ModelAdapter for TextDetectionAdapter {
type Task = TextDetectionTask;
fn info(&self) -> AdapterInfo {
self.info.clone()
}
fn execute(
&self,
input: <Self::Task as Task>::Input,
config: Option<&<Self::Task as Task>::Config>,
) -> Result<<Self::Task as Task>::Output, OCRError> {
let effective_config = config.unwrap_or(&self.config);
let model_output = self.model
.forward(
input.images,
effective_config.score_threshold,
effective_config.box_threshold,
effective_config.unclip_ratio,
)
.map_err(|e| {
OCRError::adapter_execution_error(
"TextDetectionAdapter",
format!(
"failed to detect text (score_threshold={}, box_threshold={}, unclip_ratio={})",
effective_config.score_threshold,
effective_config.box_threshold,
effective_config.unclip_ratio
),
e,
)
})?;
let detections = model_output
.boxes
.into_iter()
.zip(model_output.scores)
.map(|(boxes, scores)| {
boxes
.into_iter()
.zip(scores)
.map(|(bbox, score)| Detection::new(bbox, score))
.collect()
})
.collect();
Ok(TextDetectionOutput { detections })
}
fn supports_batching(&self) -> bool {
true
}
fn recommended_batch_size(&self) -> usize {
8
}
}
// Generates `TextDetectionAdapterBuilder`, which turns a `TextDetectionConfig`, an optional
// text-type hint and a model path into a ready-to-use `TextDetectionAdapter`.
// Only plain `//` comments are used inside the invocation so the macro's matcher sees
// exactly the same tokens.
impl_adapter_builder! {
    builder_name: TextDetectionAdapterBuilder,
    adapter_name: TextDetectionAdapter,
    config_type: TextDetectionConfig,
    adapter_type: "text_detection",
    adapter_desc: "Detects text regions in images with bounding boxes",
    task_type: TextDetection,
    fields: {
        // Optional text-type hint. "seal" (any ASCII casing) selects polygon boxes; the
        // raw value is also passed to `db_preprocess_for_text_type`.
        text_type: Option<String> = None,
    },
    methods: {
        // Sets the text-type hint used to choose preprocessing defaults and box shape.
        pub fn text_type(mut self, text_type: impl Into<String>) -> Self {
            self.text_type = Some(text_type.into());
            self
        }
    }
    build: |builder: TextDetectionAdapterBuilder, model_path: &std::path::Path| -> Result<TextDetectionAdapter, OCRError> {
        // Validate the config and split it into task settings and the `ort` settings
        // consumed by `apply_ort_config!`. Validation failures become config errors.
        let (task_config, ort_config) = builder
            .config
            .into_validated_parts()
            .map_err(|err| OCRError::ConfigError {
                message: err.to_string(),
            })?;

        // ASCII case-insensitive comparison; equivalent to lowercasing then comparing,
        // without allocating.
        let is_seal_text = match builder.text_type.as_deref() {
            Some(kind) => kind.eq_ignore_ascii_case("seal"),
            None => false,
        };

        // Start from the text-type-specific preprocessing defaults; any limit that is
        // explicitly set in the task config takes precedence.
        let mut preprocess_config =
            super::preprocessing::db_preprocess_for_text_type(builder.text_type.as_deref());
        preprocess_config.limit_side_len = task_config
            .limit_side_len
            .or(preprocess_config.limit_side_len);
        if let Some(limit_type) = &task_config.limit_type {
            preprocess_config.limit_type = Some(limit_type.clone());
        }
        preprocess_config.max_side_limit = task_config
            .max_side_len
            .or(preprocess_config.max_side_limit);

        // Post-processing defaults baked into the model. `execute` additionally passes
        // the thresholds to `DBModel::forward` on every call.
        let postprocess_config = DBPostprocessConfig {
            score_threshold: task_config.score_threshold,
            box_threshold: task_config.box_threshold,
            unclip_ratio: task_config.unclip_ratio,
            max_candidates: task_config.max_candidates,
            use_dilation: false,
            score_mode: ScoreMode::Fast,
            box_type: if is_seal_text { BoxType::Poly } else { BoxType::Quad },
        };

        let model = apply_ort_config!(
            DBModelBuilder::new()
                .preprocess_config(preprocess_config)
                .postprocess_config(postprocess_config),
            ort_config
        )
        .build(model_path)?;

        Ok(TextDetectionAdapter {
            model,
            info: TextDetectionAdapterBuilder::base_adapter_info(),
            config: task_config,
        })
    },
}