oar_ocr/core/traits/
standard.rs

1//! Traits for the OCR pipeline.
2//!
3//! This module defines various traits that are used throughout the OCR pipeline
4//! to provide a consistent interface for different components.
5
6use crate::core::{PredictionResult, batch::BatchData, batch::BatchSampler, errors::OCRError};
7use image::RgbImage;
8use std::path::Path;
9use std::sync::Arc;
10
11/// Trait for building predictors.
12///
13/// This trait defines the interface for building predictors with specific configurations.
14pub trait PredictorBuilder: Sized {
15    /// The configuration type for this builder.
16    type Config;
17
18    /// The predictor type that this builder creates.
19    type Predictor;
20
21    /// Builds a typed predictor.
22    ///
23    /// # Arguments
24    ///
25    /// * `model_path` - The path to the model file.
26    ///
27    /// # Returns
28    ///
29    /// A Result containing the built predictor or an error.
30    fn build_typed(self, model_path: &Path) -> crate::core::OcrResult<Self::Predictor>;
31
32    /// Gets the type of predictor that this builder creates.
33    ///
34    /// # Returns
35    ///
36    /// The predictor type as a string.
37    fn predictor_type(&self) -> &str;
38
39    /// Configures the builder with the given configuration.
40    ///
41    /// # Arguments
42    ///
43    /// * `config` - The configuration to use.
44    ///
45    /// # Returns
46    ///
47    /// The configured builder.
48    fn with_config(self, config: Self::Config) -> Self;
49
50    /// Builds a predictor (alias for build_typed).
51    ///
52    /// # Arguments
53    ///
54    /// * `model_path` - The path to the model file.
55    ///
56    /// # Returns
57    ///
58    /// A Result containing the built predictor or an error.
59    fn build(self, model_path: &Path) -> crate::core::OcrResult<Self::Predictor> {
60        self.build_typed(model_path)
61    }
62}
63
/// Groups items of type `T` into batches.
///
/// Only [`Sampler::sample`] has to be implemented; the slice and iterator
/// variants gather their input into a `Vec<T>` and forward to it.
pub trait Sampler<T> {
    /// The batch type this sampler emits.
    type BatchData;

    /// Splits an owned vector of items into batches.
    ///
    /// # Arguments
    ///
    /// * `data` - The items to batch.
    ///
    /// # Returns
    ///
    /// The batches produced from `data`.
    fn sample(&self, data: Vec<T>) -> Vec<Self::BatchData>;

    /// Batches the items of a borrowed slice.
    ///
    /// Every element is cloned into a fresh vector, which is then handed to
    /// [`Sampler::sample`].
    ///
    /// # Arguments
    ///
    /// * `data` - The slice of items to batch.
    ///
    /// # Returns
    ///
    /// The batches produced from a copy of `data`.
    ///
    /// # Constraints
    ///
    /// * `T` must implement `Clone`.
    fn sample_slice(&self, data: &[T]) -> Vec<Self::BatchData>
    where
        T: Clone,
    {
        let owned: Vec<T> = data.iter().cloned().collect();
        self.sample(owned)
    }

    /// Batches the items yielded by anything iterable.
    ///
    /// The iterator is drained into a vector, which is then handed to
    /// [`Sampler::sample`].
    ///
    /// # Arguments
    ///
    /// * `data` - The source of items to batch.
    ///
    /// # Returns
    ///
    /// The batches produced from the collected items.
    ///
    /// # Constraints
    ///
    /// * `I` must implement `IntoIterator<Item = T>`.
    fn sample_iter<I>(&self, data: I) -> Vec<Self::BatchData>
    where
        I: IntoIterator<Item = T>,
    {
        let mut gathered: Vec<T> = Vec::new();
        gathered.extend(data);
        self.sample(gathered)
    }
}
122
123/// Trait for base predictors in the OCR pipeline.
124///
125/// This trait defines the interface for base predictors that process batch data.
126pub trait BasePredictor {
127    /// The result type of this predictor.
128    type Result;
129
130    /// The error type of this predictor.
131    type Error;
132
133    /// Processes the given batch data.
134    ///
135    /// # Arguments
136    ///
137    /// * `batch_data` - The batch data to process.
138    ///
139    /// # Returns
140    ///
141    /// A Result containing the processing result or an error.
142    fn process(&self, batch_data: BatchData) -> Result<Self::Result, Self::Error>;
143
144    /// Converts the processing result to a prediction result.
145    ///
146    /// # Arguments
147    ///
148    /// * `result` - The processing result to convert.
149    ///
150    /// # Returns
151    ///
152    /// The converted prediction result.
153    fn convert_to_prediction_result(&self, result: Self::Result) -> PredictionResult<'static>;
154
155    /// Gets the batch sampler used by this predictor.
156    ///
157    /// # Returns
158    ///
159    /// A reference to the batch sampler.
160    fn batch_sampler(&self) -> &BatchSampler;
161
162    /// Gets the name of the model used by this predictor.
163    ///
164    /// # Returns
165    ///
166    /// The name of the model.
167    fn model_name(&self) -> &str;
168
169    /// Gets the type name of this predictor.
170    ///
171    /// # Returns
172    ///
173    /// The type name of the predictor.
174    fn predictor_type_name(&self) -> &str;
175}
176
177/// Trait for reading images.
178///
179/// This trait defines the interface for reading images from paths.
180pub trait ImageReader {
181    /// The error type of this image reader.
182    type Error;
183
184    /// Applies the image reader to the given paths.
185    ///
186    /// # Arguments
187    ///
188    /// * `imgs` - An iterator of paths to the images to read.
189    ///
190    /// # Returns
191    ///
192    /// A Result containing a vector of RGB images or an error.
193    ///
194    /// # Constraints
195    ///
196    /// * `P` must implement `AsRef<Path>` + Send + Sync.
197    fn apply<P: AsRef<Path> + Send + Sync>(
198        &self,
199        imgs: impl IntoIterator<Item = P>,
200    ) -> Result<Vec<RgbImage>, Self::Error>;
201
202    /// Reads a single image from the given path.
203    ///
204    /// # Arguments
205    ///
206    /// * `img_path` - The path to the image to read.
207    ///
208    /// # Returns
209    ///
210    /// A Result containing the RGB image or an error.
211    ///
212    /// # Constraints
213    ///
214    /// * `P` must implement `AsRef<Path>` + Send + Sync.
215    fn read_single<P: AsRef<Path> + Send + Sync>(
216        &self,
217        img_path: P,
218    ) -> Result<RgbImage, Self::Error>
219    where
220        Self::Error: From<OCRError>,
221    {
222        let mut results = self.apply(std::iter::once(img_path))?;
223        results.pop().ok_or_else(|| {
224            // Create a proper error instead of panicking
225            OCRError::invalid_input("ImageReader::apply returned empty result for single image")
226                .into()
227        })
228    }
229}
230
231/// Trait for predictor configurations.
232///
233/// This trait defines the interface for predictor configurations.
234pub trait PredictorConfig {
235    /// Gets the name of the model.
236    ///
237    /// # Returns
238    ///
239    /// The name of the model.
240    fn model_name(&self) -> &str;
241
242    /// Gets the batch size.
243    ///
244    /// # Returns
245    ///
246    /// The batch size.
247    fn batch_size(&self) -> usize;
248
249    /// Validates the configuration.
250    ///
251    /// # Returns
252    ///
253    /// A Result indicating success or an error.
254    fn validate(&self) -> crate::core::OcrResult<()>;
255
256    /// Validates the batch size.
257    ///
258    /// # Returns
259    ///
260    /// A Result indicating success or an error.
261    ///
262    /// # Validation Rules
263    ///
264    /// * Batch size must be greater than 0.
265    /// * Batch size should not exceed 1000 for memory efficiency.
266    fn validate_batch_size(&self) -> crate::core::OcrResult<()> {
267        let batch_size = self.batch_size();
268        if batch_size == 0 {
269            return Err(OCRError::ConfigError {
270                message: "Batch size must be greater than 0".to_string(),
271            });
272        }
273        if batch_size > 1000 {
274            return Err(OCRError::ConfigError {
275                message: "Batch size should not exceed 1000 for memory efficiency".to_string(),
276            });
277        }
278        Ok(())
279    }
280
281    /// Validates the model name.
282    ///
283    /// # Returns
284    ///
285    /// A Result indicating success or an error.
286    ///
287    /// # Validation Rules
288    ///
289    /// * Model name cannot be empty.
290    fn validate_model_name(&self) -> crate::core::OcrResult<()> {
291        let name = self.model_name();
292        if name.is_empty() {
293            return Err(OCRError::ConfigError {
294                message: "Model name cannot be empty".to_string(),
295            });
296        }
297        Ok(())
298    }
299}
300
301/// Trait for standard predictors.
302///
303/// This trait defines the interface for standard predictors that follow
304/// a standard pipeline of reading images, preprocessing, inference, and postprocessing.
305pub trait StandardPredictor {
306    /// The configuration type for this predictor.
307    type Config;
308
309    /// The result type of this predictor.
310    type Result;
311
312    /// The output type of the preprocessing step.
313    type PreprocessOutput;
314
315    /// The output type of the inference step.
316    type InferenceOutput;
317
318    /// Reads images from the given paths.
319    ///
320    /// # Arguments
321    ///
322    /// * `paths` - An iterator of paths to the images to read.
323    ///
324    /// # Returns
325    ///
326    /// A Result containing a vector of RGB images or an error.
327    fn read_images<'a>(
328        &self,
329        paths: impl Iterator<Item = &'a str>,
330    ) -> Result<Vec<RgbImage>, OCRError>;
331
332    /// Preprocesses the given images.
333    ///
334    /// # Arguments
335    ///
336    /// * `images` - The images to preprocess.
337    /// * `config` - Optional configuration for preprocessing.
338    ///
339    /// # Returns
340    ///
341    /// A Result containing the preprocessed output or an error.
342    fn preprocess(
343        &self,
344        images: Vec<RgbImage>,
345        config: Option<&Self::Config>,
346    ) -> crate::core::OcrResult<Self::PreprocessOutput>;
347
348    /// Performs inference on the preprocessed input.
349    ///
350    /// # Arguments
351    ///
352    /// * `input` - The preprocessed input.
353    ///
354    /// # Returns
355    ///
356    /// A Result containing the inference output or an error.
357    fn infer(
358        &self,
359        input: &Self::PreprocessOutput,
360    ) -> crate::core::OcrResult<Self::InferenceOutput>;
361
362    /// Postprocesses the inference output.
363    ///
364    /// # Arguments
365    ///
366    /// * `output` - The inference output to postprocess.
367    /// * `preprocessed` - The preprocessed input.
368    /// * `batch_data` - The batch data.
369    /// * `raw_images` - The raw images.
370    /// * `config` - Optional configuration for postprocessing.
371    ///
372    /// # Returns
373    ///
374    /// A Result containing the final result or an error.
375    fn postprocess(
376        &self,
377        output: Self::InferenceOutput,
378        preprocessed: &Self::PreprocessOutput,
379        batch_data: &BatchData,
380        raw_images: Vec<RgbImage>,
381        config: Option<&Self::Config>,
382    ) -> crate::core::OcrResult<Self::Result>;
383
384    /// Performs prediction directly from in-memory images.
385    ///
386    /// This method bypasses file I/O by working directly with RgbImage instances,
387    /// providing better performance when images are already in memory. This is
388    /// the primary prediction method for most use cases.
389    ///
390    /// # Arguments
391    ///
392    /// * `images` - Vector of images to process
393    /// * `config` - Optional configuration for the prediction
394    ///
395    /// # Returns
396    ///
397    /// A Result containing the prediction result or an OCRError
398    fn predict(
399        &self,
400        images: Vec<RgbImage>,
401        config: Option<Self::Config>,
402    ) -> crate::core::OcrResult<Self::Result> {
403        if images.is_empty() {
404            return self.empty_result();
405        }
406
407        let batch_data = self.create_dummy_batch_data(images.len());
408        let preprocessed = self.preprocess(images.clone(), config.as_ref())?;
409        let inference_output = self.infer(&preprocessed)?;
410        self.postprocess(
411            inference_output,
412            &preprocessed,
413            &batch_data,
414            images,
415            config.as_ref(),
416        )
417    }
418
419    /// Creates dummy batch data for in-memory processing.
420    ///
421    /// This method creates BatchData with dummy paths for in-memory processing,
422    /// allowing the postprocessing step to work correctly without actual file paths.
423    ///
424    /// # Arguments
425    ///
426    /// * `count` - Number of images to create batch data for
427    ///
428    /// # Returns
429    ///
430    /// BatchData with dummy paths and sequential indexes
431    fn create_dummy_batch_data(&self, count: usize) -> BatchData {
432        let dummy_paths: Vec<Arc<str>> = (0..count)
433            .map(|i| Arc::from(format!("in_memory_{i}")))
434            .collect();
435        BatchData {
436            instances: dummy_paths.clone(),
437            input_paths: dummy_paths,
438            indexes: (0..count).collect(),
439        }
440    }
441
442    /// Returns an empty result for the predictor type.
443    ///
444    /// This method should return an empty instance of the result type,
445    /// typically used when processing an empty list of images.
446    ///
447    /// # Returns
448    ///
449    /// A Result containing an empty result instance
450    fn empty_result(&self) -> crate::core::OcrResult<Self::Result>;
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::OCRError;
    use image::RgbImage;
    use std::path::Path;

    /// Reader stub whose `apply` succeeds but never yields an image, used to
    /// drive the empty-result branch of `read_single`.
    struct MockEmptyImageReader;

    impl ImageReader for MockEmptyImageReader {
        type Error = OCRError;

        fn apply<P: AsRef<Path> + Send + Sync>(
            &self,
            _imgs: impl IntoIterator<Item = P>,
        ) -> Result<Vec<RgbImage>, Self::Error> {
            // No images at all, regardless of how many paths were passed in.
            Ok(vec![])
        }
    }

    #[test]
    fn test_read_single_handles_empty_result_properly() {
        let result = MockEmptyImageReader.read_single("test_path.jpg");

        // An empty `apply` result must surface as an error, not a panic.
        assert!(result.is_err());

        // The error must be the invalid-input variant carrying the expected text.
        if let OCRError::InvalidInput { message } = result.unwrap_err() {
            assert!(
                message.contains("ImageReader::apply returned empty result for single image")
            );
        } else {
            panic!("Expected InvalidInput error");
        }
    }
}
493}