use crate::core::OCRError;
use image::RgbImage;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
// Expands the crate-wide task registry through `impl_task_type_enum`.
// `TaskType` is not defined or imported anywhere else in this module, so it
// presumably comes from this expansion. Variants such as `TaskType::TextDetection` /
// `TaskType::TextRecognition` and a `name()` accessor are used below.
// NOTE(review): inferred from the macro names and the uses in this file — confirm
// against the `with_task_registry!` / `impl_task_type_enum!` definitions.
crate::with_task_registry!(crate::impl_task_type_enum);
/// Declarative description of a task's data contract.
///
/// Inputs and outputs are identified by plain string tags (for example
/// `"image"` or `"text_boxes"`). Two schemas can be chained when the producer
/// emits a tag that the consumer accepts (see `TaskSchema::is_compatible_with`).
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct TaskSchema {
    /// The task this schema describes.
    pub task_type: TaskType,
    /// Tags naming the kinds of data the task accepts.
    pub input_types: Vec<String>,
    /// Tags naming the kinds of data the task emits.
    pub output_types: Vec<String>,
    /// Optional metadata schema; `TaskSchema::new` leaves this as `None`.
    pub metadata_schema: Option<String>,
}
impl TaskSchema {
    /// Builds a schema for `task_type` with the given input/output tags.
    ///
    /// `metadata_schema` starts out unset (`None`).
    pub fn new(task_type: TaskType, input_types: Vec<String>, output_types: Vec<String>) -> Self {
        let metadata_schema = None;
        Self {
            task_type,
            input_types,
            output_types,
            metadata_schema,
        }
    }

    /// Reports whether this schema's output can feed `target`.
    ///
    /// Returns `true` as soon as any output tag of `self` appears among the
    /// input tags of `target`. Returns `false` otherwise, including when either
    /// list is empty. The check is directional: `a.is_compatible_with(&b)` says
    /// nothing about `b.is_compatible_with(&a)`.
    pub fn is_compatible_with(&self, target: &TaskSchema) -> bool {
        for produced in &self.output_types {
            if target.input_types.iter().any(|accepted| accepted == produced) {
                return true;
            }
        }
        false
    }
}
/// A unit of pipeline work with strongly typed configuration, input and output.
///
/// Implementors must be `Send + Sync + Debug`, and so must their associated types.
pub trait Task
where
    Self: Send + Sync + Debug,
{
    /// Configuration value used to parameterize the task.
    type Config: Clone + Debug + Send + Sync;
    /// Data the task consumes.
    type Input: Debug + Send + Sync;
    /// Data the task produces.
    type Output: Debug + Send + Sync;

    /// Identifies which kind of task this is.
    fn task_type(&self) -> TaskType;

    /// Describes the data tags this task accepts and emits.
    fn schema(&self) -> TaskSchema;

    /// Checks `input`, returning an `OCRError` if it is rejected.
    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError>;

    /// Checks `output`, returning an `OCRError` if it is rejected.
    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError>;

    /// Returns an "empty" `Self::Output`; what empty means is up to the implementor.
    fn empty_output(&self) -> Self::Output;

    /// Human-readable label of the form `"Task: <task type name>"`.
    fn description(&self) -> String {
        let kind = self.task_type();
        format!("Task: {}", kind.name())
    }
}
/// Pairs a `Task` implementation with the configuration it should run with.
///
/// The runner owns both values. Type queries and validation calls are
/// forwarded unchanged to the wrapped task.
#[derive(Debug)]
pub struct TaskRunner<T: Task> {
    task: T,
    config: T::Config,
}

impl<T: Task> TaskRunner<T> {
    /// Wraps `task` together with its `config`.
    pub fn new(task: T, config: T::Config) -> Self {
        TaskRunner { config, task }
    }

    /// Borrows the wrapped task.
    pub fn task(&self) -> &T {
        &self.task
    }

    /// Borrows the stored configuration.
    pub fn config(&self) -> &T::Config {
        &self.config
    }

    /// Returns the wrapped task's `TaskType`.
    pub fn task_type(&self) -> TaskType {
        T::task_type(&self.task)
    }

    /// Runs the wrapped task's input validation on `input`.
    pub fn validate_input(&self, input: &T::Input) -> Result<(), OCRError> {
        T::validate_input(&self.task, input)
    }

    /// Runs the wrapped task's output validation on `output`.
    pub fn validate_output(&self, output: &T::Output) -> Result<(), OCRError> {
        T::validate_output(&self.task, output)
    }
}
/// A batch of RGB images with one optional metadata string per image.
#[derive(Debug, Clone)]
pub struct ImageTaskInput {
    /// The images to process, in order.
    pub images: Vec<RgbImage>,
    /// Per-image metadata; `new` keeps this the same length as `images`.
    pub metadata: Vec<Option<String>>,
}

impl ImageTaskInput {
    /// Wraps `images` and attaches a `None` metadata entry to each of them.
    pub fn new(images: Vec<RgbImage>) -> Self {
        let metadata = vec![None; images.len()];
        Self { images, metadata }
    }

    /// Wraps `images` with caller-supplied per-image `metadata`.
    ///
    /// Both vectors are stored exactly as given. Their lengths are *not*
    /// checked against each other, so keeping them aligned is the caller's job.
    pub fn with_metadata(images: Vec<RgbImage>, metadata: Vec<Option<String>>) -> Self {
        ImageTaskInput { images, metadata }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Task` implementation used to exercise `TaskRunner` and the
    /// trait's default `description`.
    #[derive(Debug)]
    struct DummyTask;

    impl Task for DummyTask {
        type Config = u32;
        type Input = Vec<u32>;
        type Output = Vec<u32>;

        fn task_type(&self) -> TaskType {
            TaskType::TextDetection
        }

        fn schema(&self) -> TaskSchema {
            TaskSchema::new(
                TaskType::TextDetection,
                vec!["image".to_string()],
                vec!["text_boxes".to_string()],
            )
        }

        fn validate_input(&self, _input: &Self::Input) -> Result<(), OCRError> {
            Ok(())
        }

        fn validate_output(&self, _output: &Self::Output) -> Result<(), OCRError> {
            Ok(())
        }

        fn empty_output(&self) -> Self::Output {
            Vec::new()
        }
    }

    #[test]
    fn test_task_type_name() {
        assert_eq!(TaskType::TextDetection.name(), "text_detection");
        assert_eq!(TaskType::TextRecognition.name(), "text_recognition");
    }

    #[test]
    fn test_schema_compatibility() {
        let detection_schema = TaskSchema::new(
            TaskType::TextDetection,
            vec!["image".to_string()],
            vec!["text_boxes".to_string()],
        );
        let recognition_schema = TaskSchema::new(
            TaskType::TextRecognition,
            vec!["text_boxes".to_string()],
            vec!["text_strings".to_string()],
        );
        assert!(detection_schema.is_compatible_with(&recognition_schema));
        assert!(!recognition_schema.is_compatible_with(&detection_schema));
    }

    #[test]
    fn test_default_description() {
        assert_eq!(DummyTask.description(), "Task: text_detection");
    }

    #[test]
    fn test_task_runner_delegates_to_task() {
        let runner = TaskRunner::new(DummyTask, 7);
        assert_eq!(*runner.config(), 7);
        assert_eq!(runner.task_type().name(), "text_detection");
        assert!(runner.validate_input(&vec![1, 2, 3]).is_ok());
        let empty = runner.task().empty_output();
        assert!(empty.is_empty());
        assert!(runner.validate_output(&empty).is_ok());
    }

    #[test]
    fn test_image_task_input_creation() {
        // Ownership of the vector is passed directly; no clone is needed.
        let input = ImageTaskInput::new(vec![RgbImage::new(100, 100), RgbImage::new(200, 200)]);
        assert_eq!(input.images.len(), 2);
        // Image order must be preserved so metadata indices line up.
        assert_eq!(input.images[0].dimensions(), (100, 100));
        assert_eq!(input.images[1].dimensions(), (200, 200));
        assert_eq!(input.metadata.len(), 2);
        assert!(input.metadata.iter().all(|m| m.is_none()));
    }

    #[test]
    fn test_image_task_input_from_owned() {
        let images = vec![RgbImage::new(100, 100), RgbImage::new(200, 200)];
        let input = ImageTaskInput::new(images);
        assert_eq!(input.images.len(), 2);
        assert_eq!(input.metadata.len(), 2);
        assert!(input.metadata.iter().all(|m| m.is_none()));
    }

    #[test]
    fn test_image_task_input_empty() {
        let input = ImageTaskInput::new(Vec::new());
        assert!(input.images.is_empty());
        assert!(input.metadata.is_empty());
    }

    #[test]
    fn test_image_task_input_with_metadata() {
        let input = ImageTaskInput::with_metadata(
            vec![RgbImage::new(10, 10), RgbImage::new(20, 20)],
            vec![Some("page-1".to_string()), None],
        );
        assert_eq!(input.images.len(), 2);
        assert_eq!(input.metadata[0].as_deref(), Some("page-1"));
        assert!(input.metadata[1].is_none());
    }
}