oar_ocr_core/core/traits/
task.rs1use crate::core::OCRError;
8use image::RgbImage;
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11
// Expands the crate's task registry through `impl_task_type_enum`.
// NOTE(review): presumably this generates the `TaskType` enum (and its `name()` method)
// used throughout this file, since `TaskType` is neither defined nor imported here —
// confirm against the macro definitions.
crate::with_task_registry!(crate::impl_task_type_enum);
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct TaskSchema {
21 pub task_type: TaskType,
23 pub input_types: Vec<String>,
25 pub output_types: Vec<String>,
27 pub metadata_schema: Option<String>,
29}
30
31impl TaskSchema {
32 pub fn new(task_type: TaskType, input_types: Vec<String>, output_types: Vec<String>) -> Self {
34 Self {
35 task_type,
36 input_types,
37 output_types,
38 metadata_schema: None,
39 }
40 }
41
42 pub fn is_compatible_with(&self, target: &TaskSchema) -> bool {
47 self.output_types
49 .iter()
50 .any(|output| target.input_types.contains(output))
51 }
52}
53
54pub trait Task: Send + Sync + Debug {
59 type Config: Send + Sync + Debug + Clone;
61
62 type Input: Send + Sync + Debug;
64
65 type Output: Send + Sync + Debug;
67
68 fn task_type(&self) -> TaskType;
70
71 fn schema(&self) -> TaskSchema;
73
74 fn validate_input(&self, input: &Self::Input) -> Result<(), OCRError>;
84
85 fn validate_output(&self, output: &Self::Output) -> Result<(), OCRError>;
95
96 fn empty_output(&self) -> Self::Output;
98
99 fn description(&self) -> String {
101 format!("Task: {}", self.task_type().name())
102 }
103}
104
105#[derive(Debug)]
110pub struct TaskRunner<T: Task> {
111 task: T,
113 config: T::Config,
115}
116
117impl<T: Task> TaskRunner<T> {
118 pub fn new(task: T, config: T::Config) -> Self {
125 Self { task, config }
126 }
127
128 pub fn task(&self) -> &T {
130 &self.task
131 }
132
133 pub fn config(&self) -> &T::Config {
135 &self.config
136 }
137
138 pub fn task_type(&self) -> TaskType {
140 self.task.task_type()
141 }
142
143 pub fn validate_input(&self, input: &T::Input) -> Result<(), OCRError> {
145 self.task.validate_input(input)
146 }
147
148 pub fn validate_output(&self, output: &T::Output) -> Result<(), OCRError> {
150 self.task.validate_output(output)
151 }
152}
153
154#[derive(Debug, Clone)]
156pub struct ImageTaskInput {
157 pub images: Vec<RgbImage>,
159 pub metadata: Vec<Option<String>>,
161}
162
163impl ImageTaskInput {
164 pub fn new(images: Vec<RgbImage>) -> Self {
166 let count = images.len();
167 Self {
168 images,
169 metadata: vec![None; count],
170 }
171 }
172
173 pub fn with_metadata(images: Vec<RgbImage>, metadata: Vec<Option<String>>) -> Self {
175 Self { images, metadata }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_type_name() {
        assert_eq!(TaskType::TextDetection.name(), "text_detection");
        assert_eq!(TaskType::TextRecognition.name(), "text_recognition");
    }

    #[test]
    fn test_schema_compatibility() {
        let detection_schema = TaskSchema::new(
            TaskType::TextDetection,
            vec!["image".to_string()],
            vec!["text_boxes".to_string()],
        );
        let recognition_schema = TaskSchema::new(
            TaskType::TextRecognition,
            vec!["text_boxes".to_string()],
            vec!["text_strings".to_string()],
        );

        // Detection emits "text_boxes", which recognition accepts.
        assert!(detection_schema.is_compatible_with(&recognition_schema));
        // Recognition emits "text_strings", which detection does not accept.
        assert!(!recognition_schema.is_compatible_with(&detection_schema));
        // `TaskSchema::new` never attaches a metadata schema.
        assert!(detection_schema.metadata_schema.is_none());
        assert!(recognition_schema.metadata_schema.is_none());
    }

    #[test]
    fn test_schema_compatibility_with_empty_types() {
        let empty_schema = TaskSchema::new(TaskType::TextDetection, Vec::new(), Vec::new());
        let recognition_schema = TaskSchema::new(
            TaskType::TextRecognition,
            vec!["text_boxes".to_string()],
            vec!["text_strings".to_string()],
        );

        // With an empty list on either side there is nothing to match, in both directions.
        assert!(!empty_schema.is_compatible_with(&recognition_schema));
        assert!(!recognition_schema.is_compatible_with(&empty_schema));
    }

    #[test]
    fn test_image_task_input_creation() {
        let images = vec![RgbImage::new(100, 100), RgbImage::new(200, 200)];
        let input = ImageTaskInput::new(images);

        assert_eq!(input.images.len(), 2);
        // Images keep the order they were supplied in.
        assert_eq!(input.images[0].dimensions(), (100, 100));
        assert_eq!(input.images[1].dimensions(), (200, 200));
        assert_eq!(input.metadata.len(), 2);
        assert!(input.metadata.iter().all(|m| m.is_none()));
    }

    #[test]
    fn test_image_task_input_empty() {
        let input = ImageTaskInput::new(Vec::new());

        assert!(input.images.is_empty());
        assert!(input.metadata.is_empty());
    }

    #[test]
    fn test_image_task_input_with_metadata() {
        let images = vec![RgbImage::new(10, 10), RgbImage::new(20, 20)];
        let metadata = vec![Some("page-1".to_string()), None];
        let input = ImageTaskInput::with_metadata(images, metadata);

        assert_eq!(input.images.len(), 2);
        assert_eq!(input.metadata.len(), 2);
        assert_eq!(input.metadata[0].as_deref(), Some("page-1"));
        assert!(input.metadata[1].is_none());
    }
}