oar_ocr_core/core/traits/
task.rs

1//! Task trait definitions for the OCR pipeline.
2//!
3//! This module defines the `Task` trait and related types that represent
4//! different OCR tasks (text detection, recognition, layout analysis, etc.).
5//! Tasks define input/output schemas and execution contracts.
6
7use crate::core::OCRError;
8use image::RgbImage;
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11
12// Generate TaskType enum from the central task registry
13crate::with_task_registry!(crate::impl_task_type_enum);
14
15/// Schema definition for task inputs and outputs.
16///
17/// This allows for validation that models produce outputs compatible
18/// with what downstream tasks expect.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct TaskSchema {
21    /// Task type
22    pub task_type: TaskType,
23    /// Expected input types (e.g., "image", "text_boxes")
24    pub input_types: Vec<String>,
25    /// Expected output types (e.g., "text_boxes", "text_strings")
26    pub output_types: Vec<String>,
27    /// Optional metadata schema
28    pub metadata_schema: Option<String>,
29}
30
31impl TaskSchema {
32    /// Creates a new task schema.
33    pub fn new(task_type: TaskType, input_types: Vec<String>, output_types: Vec<String>) -> Self {
34        Self {
35            task_type,
36            input_types,
37            output_types,
38            metadata_schema: None,
39        }
40    }
41
42    /// Validates that this schema is compatible with another schema.
43    ///
44    /// Returns true if the output types of this schema match the input types
45    /// of the target schema.
46    pub fn is_compatible_with(&self, target: &TaskSchema) -> bool {
47        // Check if any of our output types match any of target's input types
48        self.output_types
49            .iter()
50            .any(|output| target.input_types.contains(output))
51    }
52}
53
54/// Core trait for OCR tasks.
55///
56/// Tasks represent distinct operations in the OCR pipeline (detection, recognition, etc.).
57/// Each task defines its input/output schema and can be executed with various model adapters.
58pub trait Task: Send + Sync + Debug {
59    /// Configuration type for this task
60    type Config: Send + Sync + Debug + Clone;
61
62    /// Input type for this task
63    type Input: Send + Sync + Debug;
64
65    /// Output type from this task
66    type Output: Send + Sync + Debug;
67
68    /// Returns the task type identifier.
69    fn task_type(&self) -> TaskType;
70
71    /// Returns the schema defining inputs and outputs for this task.
72    fn schema(&self) -> TaskSchema;
73
74    /// Validates that the given input is suitable for this task.
75    ///
76    /// # Arguments
77    ///
78    /// * `input` - The input to validate
79    ///
80    /// # Returns
81    ///
82    /// Result indicating success or validation error
83    fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError>;
84
85    /// Validates that the given output matches the expected schema.
86    ///
87    /// # Arguments
88    ///
89    /// * `output` - The output to validate
90    ///
91    /// # Returns
92    ///
93    /// Result indicating success or validation error
94    fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError>;
95
96    /// Returns an empty output instance for when no valid results are produced.
97    fn empty_output(&self) -> Self::Output;
98
99    /// Returns a human-readable description of this task.
100    fn description(&self) -> String {
101        format!("Task: {}", self.task_type().name())
102    }
103}
104
105/// A task runner that executes tasks using a model adapter.
106///
107/// This struct coordinates the execution of a task with a specific model,
108/// handling validation and error propagation.
109#[derive(Debug)]
110pub struct TaskRunner<T: Task> {
111    /// The task to execute
112    task: T,
113    /// Configuration for the task
114    config: T::Config,
115}
116
117impl<T: Task> TaskRunner<T> {
118    /// Creates a new task runner.
119    ///
120    /// # Arguments
121    ///
122    /// * `task` - The task to execute
123    /// * `config` - Configuration for the task
124    pub fn new(task: T, config: T::Config) -> Self {
125        Self { task, config }
126    }
127
128    /// Returns a reference to the task.
129    pub fn task(&self) -> &T {
130        &self.task
131    }
132
133    /// Returns a reference to the configuration.
134    pub fn config(&self) -> &T::Config {
135        &self.config
136    }
137
138    /// Returns the task type.
139    pub fn task_type(&self) -> TaskType {
140        self.task.task_type()
141    }
142
143    /// Validates input before execution.
144    pub fn validate_input(&self, input: &T::Input) -> Result<(), OCRError> {
145        self.task.validate_input(input)
146    }
147
148    /// Validates output after execution.
149    pub fn validate_output(&self, output: &T::Output) -> Result<(), OCRError> {
150        self.task.validate_output(output)
151    }
152}
153
154/// Common input type for image-based tasks.
155#[derive(Debug, Clone)]
156pub struct ImageTaskInput {
157    /// Input images
158    pub images: Vec<RgbImage>,
159    /// Optional metadata per image
160    pub metadata: Vec<Option<String>>,
161}
162
163impl ImageTaskInput {
164    /// Creates a new image task input from owned images.
165    pub fn new(images: Vec<RgbImage>) -> Self {
166        let count = images.len();
167        Self {
168            images,
169            metadata: vec![None; count],
170        }
171    }
172
173    /// Creates a new image task input with metadata.
174    pub fn with_metadata(images: Vec<RgbImage>, metadata: Vec<Option<String>>) -> Self {
175        Self { images, metadata }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Two blank images with distinct dimensions, shared by the input tests.
    fn sample_images() -> Vec<RgbImage> {
        vec![RgbImage::new(100, 100), RgbImage::new(200, 200)]
    }

    /// Task types expose their snake_case registry names.
    #[test]
    fn test_task_type_name() {
        assert_eq!(TaskType::TextDetection.name(), "text_detection");
        assert_eq!(TaskType::TextRecognition.name(), "text_recognition");
    }

    /// Compatibility is directional: producer outputs must overlap consumer inputs.
    #[test]
    fn test_schema_compatibility() {
        let detection_schema = TaskSchema::new(
            TaskType::TextDetection,
            vec!["image".to_string()],
            vec!["text_boxes".to_string()],
        );

        let recognition_schema = TaskSchema::new(
            TaskType::TextRecognition,
            vec!["text_boxes".to_string()],
            vec!["text_strings".to_string()],
        );

        // Detection output (text_boxes) should be compatible with recognition input (text_boxes)
        assert!(detection_schema.is_compatible_with(&recognition_schema));

        // Recognition output (text_strings) is not compatible with detection input (image)
        assert!(!recognition_schema.is_compatible_with(&detection_schema));
    }

    /// `new` pairs every image with a `None` metadata entry, including for empty input.
    #[test]
    fn test_image_task_input_creation() {
        // The vector is moved in directly; no clone is needed because it is not reused.
        let input = ImageTaskInput::new(sample_images());

        assert_eq!(input.images.len(), 2);
        assert_eq!(input.metadata.len(), 2);
        assert!(input.metadata.iter().all(|m| m.is_none()));

        let empty = ImageTaskInput::new(Vec::new());
        assert!(empty.images.is_empty());
        assert!(empty.metadata.is_empty());
    }

    /// Owned images are stored unchanged and in their original order.
    #[test]
    fn test_image_task_input_from_owned() {
        let input = ImageTaskInput::new(sample_images());

        assert_eq!(input.images.len(), 2);
        assert_eq!(input.images[0].dimensions(), (100, 100));
        assert_eq!(input.images[1].dimensions(), (200, 200));
        assert_eq!(input.metadata.len(), 2);
        assert!(input.metadata.iter().all(|m| m.is_none()));
    }

    /// `with_metadata` stores the supplied metadata as given.
    #[test]
    fn test_image_task_input_with_metadata() {
        let input =
            ImageTaskInput::with_metadata(sample_images(), vec![Some("first".to_string()), None]);

        assert_eq!(input.images.len(), 2);
        assert_eq!(input.metadata, vec![Some("first".to_string()), None]);
    }
}