oar_ocr/core/traits/standard.rs
1//! Traits for the OCR pipeline.
2//!
3//! This module defines various traits that are used throughout the OCR pipeline
4//! to provide a consistent interface for different components.
5
6use crate::core::{PredictionResult, batch::BatchData, batch::BatchSampler, errors::OCRError};
7use image::RgbImage;
8use std::path::Path;
9use std::sync::Arc;
10
11/// Trait for building predictors.
12///
13/// This trait defines the interface for building predictors with specific configurations.
14pub trait PredictorBuilder: Sized {
15 /// The configuration type for this builder.
16 type Config;
17
18 /// The predictor type that this builder creates.
19 type Predictor;
20
21 /// Builds a typed predictor.
22 ///
23 /// # Arguments
24 ///
25 /// * `model_path` - The path to the model file.
26 ///
27 /// # Returns
28 ///
29 /// A Result containing the built predictor or an error.
30 fn build_typed(self, model_path: &Path) -> crate::core::OcrResult<Self::Predictor>;
31
32 /// Gets the type of predictor that this builder creates.
33 ///
34 /// # Returns
35 ///
36 /// The predictor type as a string.
37 fn predictor_type(&self) -> &str;
38
39 /// Configures the builder with the given configuration.
40 ///
41 /// # Arguments
42 ///
43 /// * `config` - The configuration to use.
44 ///
45 /// # Returns
46 ///
47 /// The configured builder.
48 fn with_config(self, config: Self::Config) -> Self;
49
50 /// Builds a predictor (alias for build_typed).
51 ///
52 /// # Arguments
53 ///
54 /// * `model_path` - The path to the model file.
55 ///
56 /// # Returns
57 ///
58 /// A Result containing the built predictor or an error.
59 fn build(self, model_path: &Path) -> crate::core::OcrResult<Self::Predictor> {
60 self.build_typed(model_path)
61 }
62}
63
/// Groups a collection of items into batches.
///
/// Only [`sample`](Self::sample) has to be implemented; the slice and
/// iterator variants are provided and forward to it after gathering their
/// input into an owned `Vec<T>`.
pub trait Sampler<T> {
    /// The batch type this sampler produces.
    type BatchData;

    /// Splits an owned vector of items into batches.
    ///
    /// # Arguments
    ///
    /// * `data` - The items to batch.
    ///
    /// # Returns
    ///
    /// A vector of batches.
    fn sample(&self, data: Vec<T>) -> Vec<Self::BatchData>;

    /// Batches the items of a borrowed slice.
    ///
    /// Every element is cloned into a fresh vector, which is then passed to
    /// [`sample`](Self::sample).
    ///
    /// # Arguments
    ///
    /// * `data` - The slice of items to batch.
    ///
    /// # Returns
    ///
    /// A vector of batches.
    ///
    /// # Constraints
    ///
    /// * `T` must implement `Clone`.
    fn sample_slice(&self, data: &[T]) -> Vec<Self::BatchData>
    where
        T: Clone,
    {
        let owned: Vec<T> = data.iter().cloned().collect();
        self.sample(owned)
    }

    /// Batches the items yielded by anything iterable.
    ///
    /// The iterator is drained into a vector, which is then passed to
    /// [`sample`](Self::sample).
    ///
    /// # Arguments
    ///
    /// * `data` - The source of items to batch.
    ///
    /// # Returns
    ///
    /// A vector of batches.
    ///
    /// # Constraints
    ///
    /// * `I` must implement `IntoIterator<Item = T>`.
    fn sample_iter<I>(&self, data: I) -> Vec<Self::BatchData>
    where
        I: IntoIterator<Item = T>,
    {
        let gathered: Vec<T> = data.into_iter().collect();
        self.sample(gathered)
    }
}
122
123/// Trait for base predictors in the OCR pipeline.
124///
125/// This trait defines the interface for base predictors that process batch data.
126pub trait BasePredictor {
127 /// The result type of this predictor.
128 type Result;
129
130 /// The error type of this predictor.
131 type Error;
132
133 /// Processes the given batch data.
134 ///
135 /// # Arguments
136 ///
137 /// * `batch_data` - The batch data to process.
138 ///
139 /// # Returns
140 ///
141 /// A Result containing the processing result or an error.
142 fn process(&self, batch_data: BatchData) -> Result<Self::Result, Self::Error>;
143
144 /// Converts the processing result to a prediction result.
145 ///
146 /// # Arguments
147 ///
148 /// * `result` - The processing result to convert.
149 ///
150 /// # Returns
151 ///
152 /// The converted prediction result.
153 fn convert_to_prediction_result(&self, result: Self::Result) -> PredictionResult<'static>;
154
155 /// Gets the batch sampler used by this predictor.
156 ///
157 /// # Returns
158 ///
159 /// A reference to the batch sampler.
160 fn batch_sampler(&self) -> &BatchSampler;
161
162 /// Gets the name of the model used by this predictor.
163 ///
164 /// # Returns
165 ///
166 /// The name of the model.
167 fn model_name(&self) -> &str;
168
169 /// Gets the type name of this predictor.
170 ///
171 /// # Returns
172 ///
173 /// The type name of the predictor.
174 fn predictor_type_name(&self) -> &str;
175}
176
177/// Trait for reading images.
178///
179/// This trait defines the interface for reading images from paths.
180pub trait ImageReader {
181 /// The error type of this image reader.
182 type Error;
183
184 /// Applies the image reader to the given paths.
185 ///
186 /// # Arguments
187 ///
188 /// * `imgs` - An iterator of paths to the images to read.
189 ///
190 /// # Returns
191 ///
192 /// A Result containing a vector of RGB images or an error.
193 ///
194 /// # Constraints
195 ///
196 /// * `P` must implement `AsRef<Path>` + Send + Sync.
197 fn apply<P: AsRef<Path> + Send + Sync>(
198 &self,
199 imgs: impl IntoIterator<Item = P>,
200 ) -> Result<Vec<RgbImage>, Self::Error>;
201
202 /// Reads a single image from the given path.
203 ///
204 /// # Arguments
205 ///
206 /// * `img_path` - The path to the image to read.
207 ///
208 /// # Returns
209 ///
210 /// A Result containing the RGB image or an error.
211 ///
212 /// # Constraints
213 ///
214 /// * `P` must implement `AsRef<Path>` + Send + Sync.
215 fn read_single<P: AsRef<Path> + Send + Sync>(
216 &self,
217 img_path: P,
218 ) -> Result<RgbImage, Self::Error>
219 where
220 Self::Error: From<OCRError>,
221 {
222 let mut results = self.apply(std::iter::once(img_path))?;
223 results.pop().ok_or_else(|| {
224 // Create a proper error instead of panicking
225 OCRError::invalid_input("ImageReader::apply returned empty result for single image")
226 .into()
227 })
228 }
229}
230
231/// Trait for predictor configurations.
232///
233/// This trait defines the interface for predictor configurations.
234pub trait PredictorConfig {
235 /// Gets the name of the model.
236 ///
237 /// # Returns
238 ///
239 /// The name of the model.
240 fn model_name(&self) -> &str;
241
242 /// Gets the batch size.
243 ///
244 /// # Returns
245 ///
246 /// The batch size.
247 fn batch_size(&self) -> usize;
248
249 /// Validates the configuration.
250 ///
251 /// # Returns
252 ///
253 /// A Result indicating success or an error.
254 fn validate(&self) -> crate::core::OcrResult<()>;
255
256 /// Validates the batch size.
257 ///
258 /// # Returns
259 ///
260 /// A Result indicating success or an error.
261 ///
262 /// # Validation Rules
263 ///
264 /// * Batch size must be greater than 0.
265 /// * Batch size should not exceed 1000 for memory efficiency.
266 fn validate_batch_size(&self) -> crate::core::OcrResult<()> {
267 let batch_size = self.batch_size();
268 if batch_size == 0 {
269 return Err(OCRError::ConfigError {
270 message: "Batch size must be greater than 0".to_string(),
271 });
272 }
273 if batch_size > 1000 {
274 return Err(OCRError::ConfigError {
275 message: "Batch size should not exceed 1000 for memory efficiency".to_string(),
276 });
277 }
278 Ok(())
279 }
280
281 /// Validates the model name.
282 ///
283 /// # Returns
284 ///
285 /// A Result indicating success or an error.
286 ///
287 /// # Validation Rules
288 ///
289 /// * Model name cannot be empty.
290 fn validate_model_name(&self) -> crate::core::OcrResult<()> {
291 let name = self.model_name();
292 if name.is_empty() {
293 return Err(OCRError::ConfigError {
294 message: "Model name cannot be empty".to_string(),
295 });
296 }
297 Ok(())
298 }
299}
300
301/// Trait for standard predictors.
302///
303/// This trait defines the interface for standard predictors that follow
304/// a standard pipeline of reading images, preprocessing, inference, and postprocessing.
305pub trait StandardPredictor {
306 /// The configuration type for this predictor.
307 type Config;
308
309 /// The result type of this predictor.
310 type Result;
311
312 /// The output type of the preprocessing step.
313 type PreprocessOutput;
314
315 /// The output type of the inference step.
316 type InferenceOutput;
317
318 /// Reads images from the given paths.
319 ///
320 /// # Arguments
321 ///
322 /// * `paths` - An iterator of paths to the images to read.
323 ///
324 /// # Returns
325 ///
326 /// A Result containing a vector of RGB images or an error.
327 fn read_images<'a>(
328 &self,
329 paths: impl Iterator<Item = &'a str>,
330 ) -> Result<Vec<RgbImage>, OCRError>;
331
332 /// Preprocesses the given images.
333 ///
334 /// # Arguments
335 ///
336 /// * `images` - The images to preprocess.
337 /// * `config` - Optional configuration for preprocessing.
338 ///
339 /// # Returns
340 ///
341 /// A Result containing the preprocessed output or an error.
342 fn preprocess(
343 &self,
344 images: Vec<RgbImage>,
345 config: Option<&Self::Config>,
346 ) -> crate::core::OcrResult<Self::PreprocessOutput>;
347
348 /// Performs inference on the preprocessed input.
349 ///
350 /// # Arguments
351 ///
352 /// * `input` - The preprocessed input.
353 ///
354 /// # Returns
355 ///
356 /// A Result containing the inference output or an error.
357 fn infer(
358 &self,
359 input: &Self::PreprocessOutput,
360 ) -> crate::core::OcrResult<Self::InferenceOutput>;
361
362 /// Postprocesses the inference output.
363 ///
364 /// # Arguments
365 ///
366 /// * `output` - The inference output to postprocess.
367 /// * `preprocessed` - The preprocessed input.
368 /// * `batch_data` - The batch data.
369 /// * `raw_images` - The raw images.
370 /// * `config` - Optional configuration for postprocessing.
371 ///
372 /// # Returns
373 ///
374 /// A Result containing the final result or an error.
375 fn postprocess(
376 &self,
377 output: Self::InferenceOutput,
378 preprocessed: &Self::PreprocessOutput,
379 batch_data: &BatchData,
380 raw_images: Vec<RgbImage>,
381 config: Option<&Self::Config>,
382 ) -> crate::core::OcrResult<Self::Result>;
383
384 /// Performs prediction directly from in-memory images.
385 ///
386 /// This method bypasses file I/O by working directly with RgbImage instances,
387 /// providing better performance when images are already in memory. This is
388 /// the primary prediction method for most use cases.
389 ///
390 /// # Arguments
391 ///
392 /// * `images` - Vector of images to process
393 /// * `config` - Optional configuration for the prediction
394 ///
395 /// # Returns
396 ///
397 /// A Result containing the prediction result or an OCRError
398 fn predict(
399 &self,
400 images: Vec<RgbImage>,
401 config: Option<Self::Config>,
402 ) -> crate::core::OcrResult<Self::Result> {
403 if images.is_empty() {
404 return self.empty_result();
405 }
406
407 let batch_data = self.create_dummy_batch_data(images.len());
408 let preprocessed = self.preprocess(images.clone(), config.as_ref())?;
409 let inference_output = self.infer(&preprocessed)?;
410 self.postprocess(
411 inference_output,
412 &preprocessed,
413 &batch_data,
414 images,
415 config.as_ref(),
416 )
417 }
418
419 /// Creates dummy batch data for in-memory processing.
420 ///
421 /// This method creates BatchData with dummy paths for in-memory processing,
422 /// allowing the postprocessing step to work correctly without actual file paths.
423 ///
424 /// # Arguments
425 ///
426 /// * `count` - Number of images to create batch data for
427 ///
428 /// # Returns
429 ///
430 /// BatchData with dummy paths and sequential indexes
431 fn create_dummy_batch_data(&self, count: usize) -> BatchData {
432 let dummy_paths: Vec<Arc<str>> = (0..count)
433 .map(|i| Arc::from(format!("in_memory_{i}")))
434 .collect();
435 BatchData {
436 instances: dummy_paths.clone(),
437 input_paths: dummy_paths,
438 indexes: (0..count).collect(),
439 }
440 }
441
442 /// Returns an empty result for the predictor type.
443 ///
444 /// This method should return an empty instance of the result type,
445 /// typically used when processing an empty list of images.
446 ///
447 /// # Returns
448 ///
449 /// A Result containing an empty result instance
450 fn empty_result(&self) -> crate::core::OcrResult<Self::Result>;
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::OCRError;
    use image::RgbImage;
    use std::path::Path;

    /// Reader stub whose `apply` succeeds with no images, which drives
    /// `read_single` down its empty-result path.
    struct MockEmptyImageReader;

    impl ImageReader for MockEmptyImageReader {
        type Error = OCRError;

        fn apply<P: AsRef<Path> + Send + Sync>(
            &self,
            _imgs: impl IntoIterator<Item = P>,
        ) -> Result<Vec<RgbImage>, Self::Error> {
            // The input paths are ignored; the result is always empty.
            Ok(Vec::new())
        }
    }

    #[test]
    fn test_read_single_handles_empty_result_properly() {
        let outcome = MockEmptyImageReader.read_single("test_path.jpg");

        // An empty result must be reported as an error, not a panic.
        assert!(outcome.is_err());

        // The error must be the invalid-input variant with the expected text.
        if let OCRError::InvalidInput { message } = outcome.unwrap_err() {
            assert!(message.contains("ImageReader::apply returned empty result for single image"));
        } else {
            panic!("Expected InvalidInput error");
        }
    }
}