use super::validation::ensure_non_empty_images;
use crate::ConfigValidator;
use crate::core::OCRError;
use crate::core::traits::TaskDefinition;
use crate::core::traits::task::{ImageTaskInput, Task, TaskSchema, TaskType};
use crate::processors::{BoundingBox, LimitType};
use crate::utils::ScoreValidator;
use serde::{Deserialize, Serialize};
/// A single text region reported by the detector.
///
/// Pairs the region's `BoundingBox` with its confidence score. The task's
/// `validate_output` requires `score` to pass the unit-range check from
/// `ScoreValidator::new_unit_range`.
#[derive(Clone, Debug)]
pub struct Detection {
    /// Outline of the detected text region.
    pub bbox: BoundingBox,
    /// Confidence score associated with `bbox`.
    pub score: f32,
}
impl Detection {
pub fn new(bbox: BoundingBox, score: f32) -> Self {
Self { bbox, score }
}
}
/// Tuning parameters for [`TextDetectionTask`].
///
/// Field bounds are declared with `#[validate]` attributes and checked by the
/// `ConfigValidator` derive. Default values come from the `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)]
pub struct TextDetectionConfig {
/// Score threshold, constrained to `[0.0, 1.0]` (default `0.3`).
// NOTE(review): presumably the probability-map binarization cutoff — confirm
// against the post-processing code that consumes this config.
#[validate(range(min = 0.0, max = 1.0))]
pub score_threshold: f32,
/// Box threshold, constrained to `[0.0, 1.0]` (default `0.6`).
// NOTE(review): presumably the minimum score for keeping a candidate box — confirm.
#[validate(range(min = 0.0, max = 1.0))]
pub box_threshold: f32,
/// Unclip ratio, lower-bounded by `#[validate(min = 0.0)]` (default `1.5`).
// NOTE(review): presumably the polygon expansion factor applied to detected
// boxes — confirm.
#[validate(min = 0.0)]
pub unclip_ratio: f32,
/// Maximum number of candidates, lower-bounded by `#[validate(min = 1)]`
/// (default `1000`).
// NOTE(review): presumably the cap on candidate regions considered per image — confirm.
#[validate(min = 1)]
pub max_candidates: usize,
/// Optional side-length limit (default `None`).
// NOTE(review): presumably used with `limit_type` to resize inputs — confirm.
pub limit_side_len: Option<u32>,
/// Optional [`LimitType`] selecting how `limit_side_len` is applied (default `None`).
pub limit_type: Option<LimitType>,
/// Optional maximum side length (default `None`).
pub max_side_len: Option<u32>,
}
impl Default for TextDetectionConfig {
    /// Returns the baseline detection settings.
    ///
    /// These are: score threshold `0.3`, box threshold `0.6`, unclip ratio
    /// `1.5`, and at most `1000` candidates. All optional side-length settings
    /// are left as `None`.
    fn default() -> Self {
        let (score_threshold, box_threshold) = (0.3, 0.6);
        Self {
            score_threshold,
            box_threshold,
            unclip_ratio: 1.5,
            max_candidates: 1000,
            limit_side_len: None,
            limit_type: None,
            max_side_len: None,
        }
    }
}
/// Result of running text detection over a batch of images.
#[derive(Clone, Debug)]
pub struct TextDetectionOutput {
    /// One inner list per input image, holding that image's detections.
    pub detections: Vec<Vec<Detection>>,
}
impl TextDetectionOutput {
    /// Creates an output with no per-image entries.
    ///
    /// Equivalent to `with_capacity(0)`, which performs no heap allocation.
    pub fn empty() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an output that reserves room for `capacity` per-image entries.
    pub fn with_capacity(capacity: usize) -> Self {
        let detections = Vec::with_capacity(capacity);
        Self { detections }
    }
}
impl TaskDefinition for TextDetectionOutput {
    const TASK_NAME: &'static str = "text_detection";
    const TASK_DOC: &'static str = "Text detection - locating text regions in images";

    /// Returns an output holding no per-image detection lists.
    fn empty() -> Self {
        Self {
            detections: Vec::new(),
        }
    }
}
/// Text-detection task: validates inputs and outputs and describes its schema.
///
/// NOTE(review): the stored configuration is not read by any code visible
/// alongside this type; confirm whether it is consumed elsewhere.
#[derive(Default, Debug)]
pub struct TextDetectionTask {
    _config: TextDetectionConfig,
}
impl TextDetectionTask {
pub fn new(config: TextDetectionConfig) -> Self {
Self { _config: config }
}
}
impl Task for TextDetectionTask {
    type Config = TextDetectionConfig;
    type Input = ImageTaskInput;
    type Output = TextDetectionOutput;

    fn task_type(&self) -> TaskType {
        TaskType::TextDetection
    }

    /// Declares a single `"image"` input and `"text_boxes"` / `"scores"` outputs.
    fn schema(&self) -> TaskSchema {
        TaskSchema::new(
            TaskType::TextDetection,
            vec!["image".to_string()],
            vec!["text_boxes".to_string(), "scores".to_string()],
        )
    }

    /// Rejects an input batch that contains no images.
    ///
    /// # Errors
    /// Propagates the error from `ensure_non_empty_images` when `input.images`
    /// is empty.
    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError> {
        ensure_non_empty_images(&input.images, "No images provided for text detection")?;
        Ok(())
    }

    /// Checks every detection score with `ScoreValidator::new_unit_range`.
    ///
    /// # Errors
    /// Stops at the first image whose scores fail validation and propagates
    /// that error. The error context names the image index and the detection
    /// index within that image.
    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError> {
        let validator = ScoreValidator::new_unit_range("score");
        // A single scratch buffer is reused for every image instead of
        // allocating a fresh Vec per image.
        let mut scores: Vec<f32> = Vec::new();
        for (image_idx, detections) in output.detections.iter().enumerate() {
            scores.clear();
            scores.extend(detections.iter().map(|d| d.score));
            validator.validate_scores_with(&scores, |det_idx| {
                format!("Image {}, detection {}", image_idx, det_idx)
            })?;
        }
        Ok(())
    }

    fn empty_output(&self) -> Self::Output {
        TextDetectionOutput::empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::Point;
    use image::RgbImage;

    /// Builds a 10x10 axis-aligned square box anchored at the origin.
    fn square_box() -> BoundingBox {
        BoundingBox::new(vec![
            Point::new(0.0, 0.0),
            Point::new(10.0, 0.0),
            Point::new(10.0, 10.0),
            Point::new(0.0, 10.0),
        ])
    }

    #[test]
    fn test_text_detection_task_creation() {
        let task = TextDetectionTask::default();
        assert_eq!(task.task_type(), TaskType::TextDetection);
    }

    #[test]
    fn test_input_validation() {
        let task = TextDetectionTask::default();
        let empty_input = ImageTaskInput::new(vec![]);
        assert!(task.validate_input(&empty_input).is_err());
        let valid_input = ImageTaskInput::new(vec![RgbImage::new(100, 100)]);
        assert!(task.validate_input(&valid_input).is_ok());
    }

    #[test]
    fn test_output_validation() {
        let task = TextDetectionTask::default();

        let output = TextDetectionOutput {
            detections: vec![vec![Detection::new(square_box(), 0.95)]],
        };
        assert!(task.validate_output(&output).is_ok());

        let bad_output = TextDetectionOutput {
            detections: vec![vec![Detection::new(square_box(), 1.5)]],
        };
        assert!(task.validate_output(&bad_output).is_err());
    }

    #[test]
    fn test_output_validation_checks_every_image() {
        // A valid first image must not mask an invalid score in a later image.
        let task = TextDetectionTask::default();
        let output = TextDetectionOutput {
            detections: vec![
                vec![Detection::new(square_box(), 0.95)],
                vec![Detection::new(square_box(), 1.5)],
            ],
        };
        assert!(task.validate_output(&output).is_err());
    }

    #[test]
    fn test_empty_output_is_valid() {
        let task = TextDetectionTask::default();
        assert!(task.validate_output(&task.empty_output()).is_ok());
    }

    #[test]
    fn test_schema() {
        let task = TextDetectionTask::default();
        let schema = task.schema();
        assert_eq!(schema.task_type, TaskType::TextDetection);
        assert!(schema.input_types.contains(&"image".to_string()));
        assert!(schema.output_types.contains(&"text_boxes".to_string()));
    }
}