use crate::core::inference::{OrtInfer, TensorInput};
use crate::core::{OCRError, validate_positive, validate_range};
use crate::processors::{
BoundingBox, BoxType, DBPostProcess, DBPostProcessConfig, DetResizeForTest, ImageScaleInfo,
LimitType, NormalizeImage, ScoreMode, TensorLayout,
};
use image::{DynamicImage, RgbImage};
use std::path::Path;
use tracing::debug;
/// Optional overrides for the DB detector's resize stage.
///
/// All fields are forwarded verbatim to `DetResizeForTest::new` when the
/// model is assembled (see `DBModelBuilder::build`); `None` leaves the
/// corresponding resizer default in effect.
#[derive(Debug, Clone, Default)]
pub struct DBPreprocessConfig {
    // Side-length limit passed to the resizer; exact interpretation depends
    // on `limit_type` — see `DetResizeForTest` for semantics.
    pub limit_side_len: Option<u32>,
    // Resize strategy selector consumed by `DetResizeForTest`.
    pub limit_type: Option<LimitType>,
    // Upper bound on image side length, forwarded to the resizer.
    pub max_side_limit: Option<u32>,
    // Target for the long side, forwarded to the resizer.
    pub resize_long: Option<u32>,
}
/// Parameters controlling how DB probability maps are decoded into boxes.
///
/// `validate` enforces the numeric ranges noted below; the values are handed
/// to `DBPostProcess::new` in `DBModelBuilder::build`.
#[derive(Debug, Clone)]
pub struct DBPostprocessConfig {
    // Pixel-level score threshold; must be within [0, 1] (see `validate`).
    pub score_threshold: f32,
    // Box-level score threshold; must be within [0, 1] (see `validate`).
    pub box_threshold: f32,
    // Expansion ratio for detected regions; checked by `validate_positive`.
    pub unclip_ratio: f32,
    // Cap on candidate boxes considered; checked by `validate_positive`.
    pub max_candidates: usize,
    // Whether the postprocessor applies dilation (consumed by `DBPostProcess`).
    pub use_dilation: bool,
    // Score computation variant (e.g. `ScoreMode::Fast` by default).
    pub score_mode: ScoreMode,
    // Output geometry variant (e.g. `BoxType::Quad` by default).
    pub box_type: BoxType,
}
impl Default for DBPostprocessConfig {
fn default() -> Self {
Self {
score_threshold: 0.3,
box_threshold: 0.7,
unclip_ratio: 1.5,
max_candidates: 1000,
use_dilation: false,
score_mode: ScoreMode::Fast,
box_type: BoxType::Quad,
}
}
}
impl DBPostprocessConfig {
    /// Validates the numeric fields: both thresholds must lie in `[0, 1]`,
    /// and `unclip_ratio` / `max_candidates` are checked via
    /// `validate_positive`.
    ///
    /// # Errors
    /// Returns the first failing check, in declaration order.
    pub fn validate(&self) -> Result<(), OCRError> {
        // `and_then` closures are lazy, so checks short-circuit exactly as
        // the equivalent `?`-chain would.
        validate_range(self.score_threshold, 0.0, 1.0, "score_threshold")
            .and_then(|_| validate_range(self.box_threshold, 0.0, 1.0, "box_threshold"))
            .and_then(|_| validate_positive(self.unclip_ratio, "unclip_ratio"))
            .and_then(|_| validate_positive(self.max_candidates, "max_candidates"))
    }
}
/// Result of a DB detection pass, as produced by `DBModel::postprocess`.
#[derive(Debug, Clone)]
pub struct DBModelOutput {
    // Detected regions; outer Vec presumably one entry per input image in
    // the batch (mirrors `DBPostProcess::apply`'s first return value) —
    // TODO(review): confirm against the postprocessor.
    pub boxes: Vec<Vec<BoundingBox>>,
    // Confidence scores parallel to `boxes` (second return value of
    // `DBPostProcess::apply`).
    pub scores: Vec<Vec<f32>>,
}
/// DB text-detection model: owns the full resize → normalize → ONNX
/// inference → box-decoding pipeline. Construct via `DBModelBuilder`.
#[derive(Debug)]
pub struct DBModel {
    // ONNX Runtime session wrapper used in `infer`.
    inference: OrtInfer,
    // Resize stage applied first in `preprocess`.
    resizer: DetResizeForTest,
    // Normalization stage producing the NCHW batch tensor.
    normalizer: NormalizeImage,
    // Decodes probability maps into boxes/scores in `postprocess`.
    postprocessor: DBPostProcess,
}
impl DBModel {
pub fn new(
inference: OrtInfer,
resizer: DetResizeForTest,
normalizer: NormalizeImage,
postprocessor: DBPostProcess,
) -> Self {
Self {
inference,
resizer,
normalizer,
postprocessor,
}
}
pub fn preprocess(
&self,
images: Vec<RgbImage>,
) -> Result<(ndarray::Array4<f32>, Vec<ImageScaleInfo>), OCRError> {
let dynamic_images: Vec<DynamicImage> =
images.into_iter().map(DynamicImage::ImageRgb8).collect();
let (resized_images, img_shapes) = self.resizer.apply(
dynamic_images,
None, None, None, );
debug!("After resize: {} images", resized_images.len());
for (i, (img, shape)) in resized_images.iter().zip(&img_shapes).enumerate() {
debug!(
" Image {}: {}x{}, shape=[src_h={:.0}, src_w={:.0}, ratio_h={:.3}, ratio_w={:.3}]",
i,
img.width(),
img.height(),
shape.src_h,
shape.src_w,
shape.ratio_h,
shape.ratio_w
);
}
let batch_tensor = self.normalizer.normalize_batch_to(resized_images)?;
debug!("Batch tensor shape: {:?}", batch_tensor.shape());
Ok((batch_tensor, img_shapes))
}
pub fn infer(
&self,
batch_tensor: &ndarray::Array4<f32>,
) -> Result<ndarray::Array4<f32>, OCRError> {
let input_name = self.inference.input_name();
let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
let outputs = self
.inference
.infer(&inputs)
.map_err(|e| OCRError::Inference {
model_name: "DB".to_string(),
context: format!(
"failed to run inference on batch with shape {:?}",
batch_tensor.shape()
),
source: Box::new(e),
})?;
let output = outputs
.into_iter()
.next()
.ok_or_else(|| OCRError::InvalidInput {
message: "DB: no output returned from inference".to_string(),
})?;
output
.1
.try_into_array4_f32()
.map_err(|e| OCRError::Inference {
model_name: "DB".to_string(),
context: "failed to convert output to 4D array".to_string(),
source: Box::new(e),
})
}
pub fn postprocess(
&self,
predictions: &ndarray::Array4<f32>,
img_shapes: Vec<ImageScaleInfo>,
score_threshold: f32,
box_threshold: f32,
unclip_ratio: f32,
) -> DBModelOutput {
let config = DBPostProcessConfig::new(score_threshold, box_threshold, unclip_ratio);
let (boxes, scores) = self
.postprocessor
.apply(predictions, img_shapes, Some(&config));
DBModelOutput { boxes, scores }
}
pub fn forward(
&self,
images: Vec<RgbImage>,
score_threshold: f32,
box_threshold: f32,
unclip_ratio: f32,
) -> Result<DBModelOutput, OCRError> {
let (batch_tensor, img_shapes) = self.preprocess(images)?;
let predictions = self.infer(&batch_tensor)?;
Ok(self.postprocess(
&predictions,
img_shapes,
score_threshold,
box_threshold,
unclip_ratio,
))
}
}
/// Fluent builder producing a fully wired [`DBModel`] via `build`.
pub struct DBModelBuilder {
    // Resize settings applied when constructing the `DetResizeForTest` stage.
    preprocess_config: DBPreprocessConfig,
    // Box-decoding settings applied when constructing `DBPostProcess`.
    postprocess_config: DBPostprocessConfig,
    // Optional ONNX Runtime session settings; when `Some`, `build` goes
    // through `OrtInfer::from_config` instead of `OrtInfer::new`.
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}
impl DBModelBuilder {
pub fn new() -> Self {
Self {
preprocess_config: DBPreprocessConfig::default(),
postprocess_config: DBPostprocessConfig::default(),
ort_config: None,
}
}
pub fn preprocess_config(mut self, config: DBPreprocessConfig) -> Self {
self.preprocess_config = config;
self
}
pub fn postprocess_config(mut self, config: DBPostprocessConfig) -> Self {
self.postprocess_config = config;
self
}
pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
self.ort_config = Some(config);
self
}
pub fn build(self, model_path: &Path) -> Result<DBModel, OCRError> {
let inference = if self.ort_config.is_some() {
use crate::core::config::ModelInferenceConfig;
let common_config = ModelInferenceConfig {
ort_session: self.ort_config,
..Default::default()
};
OrtInfer::from_config(&common_config, model_path, Some("x"))?
} else {
OrtInfer::new(model_path, Some("x"))?
};
let resizer = DetResizeForTest::new(
None, None, None, self.preprocess_config.limit_side_len, self.preprocess_config.limit_type, self.preprocess_config.resize_long, self.preprocess_config.max_side_limit, );
let normalizer = NormalizeImage::with_color_order(
Some(1.0 / 255.0), Some(vec![0.485, 0.456, 0.406]), Some(vec![0.229, 0.224, 0.225]), Some(TensorLayout::CHW), Some(crate::processors::types::ColorOrder::BGR),
)?;
let postprocessor = DBPostProcess::new(
Some(self.postprocess_config.score_threshold),
Some(self.postprocess_config.box_threshold),
Some(self.postprocess_config.max_candidates),
Some(self.postprocess_config.unclip_ratio),
Some(self.postprocess_config.use_dilation),
Some(self.postprocess_config.score_mode),
Some(self.postprocess_config.box_type),
);
Ok(DBModel::new(inference, resizer, normalizer, postprocessor))
}
}
impl Default for DBModelBuilder {
fn default() -> Self {
Self::new()
}
}