oar_ocr/pipeline/oarocr/
config.rs

//! Configuration types for the OAROCR pipeline.
2
3use crate::core::DynamicBatchConfig;
4use crate::pipeline::stages::{OrientationConfig, TextLineOrientationConfig};
5use crate::predictor::{
6    DocOrientationClassifierConfig, DoctrRectifierPredictorConfig, TextDetPredictorConfig,
7    TextLineClasPredictorConfig, TextRecPredictorConfig,
8};
9use crate::processors::{AspectRatioBucketingConfig, LimitType};
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12
13pub use crate::core::config::{OnnxThreadingConfig, ParallelPolicy};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct OAROCRConfig {
17    /// Configuration for text detection.
18    #[serde(default)]
19    pub detection: TextDetPredictorConfig,
20
21    /// Configuration for text recognition.
22    #[serde(default)]
23    pub recognition: TextRecPredictorConfig,
24
25    /// Configuration for document orientation classification (optional).
26    #[serde(default)]
27    pub orientation: Option<DocOrientationClassifierConfig>,
28
29    /// Configuration for document rectification/unwarping (optional).
30    #[serde(default)]
31    pub rectification: Option<DoctrRectifierPredictorConfig>,
32
33    /// Configuration for text line orientation classification (optional).
34    #[serde(default)]
35    pub text_line_orientation: Option<TextLineClasPredictorConfig>,
36
37    /// Configuration for document orientation stage processing.
38    #[serde(default)]
39    pub orientation_stage: Option<OrientationConfig>,
40
41    /// Configuration for text line orientation stage processing.
42    #[serde(default)]
43    pub text_line_orientation_stage: Option<TextLineOrientationConfig>,
44
45    /// Path to the character dictionary file for text recognition.
46    pub character_dict_path: PathBuf,
47
48    /// Whether to use document orientation classification.
49    #[serde(default)]
50    pub use_doc_orientation_classify: bool,
51
52    /// Whether to use document unwarping.
53    #[serde(default)]
54    pub use_doc_unwarping: bool,
55
56    /// Whether to use text line orientation classification.
57    #[serde(default)]
58    pub use_textline_orientation: bool,
59
60    /// Configuration for aspect ratio bucketing in text recognition.
61    /// If None, falls back to exact dimension grouping.
62    #[serde(default)]
63    pub aspect_ratio_bucketing: Option<AspectRatioBucketingConfig>,
64
65    /// Configuration for dynamic batching across multiple images.
66    /// If None, uses default dynamic batching configuration.
67    #[serde(default)]
68    pub dynamic_batching: Option<DynamicBatchConfig>,
69
70    /// Centralized parallel processing policy configuration
71    #[serde(default)]
72    pub parallel_policy: ParallelPolicy,
73}
74
75impl OAROCRConfig {
76    /// Creates a new OAROCRConfig with the required parameters.
77    ///
78    /// This constructor initializes the configuration with default values
79    /// for optional parameters while requiring the essential model paths.
80    ///
81    /// # Arguments
82    ///
83    /// * `text_detection_model_path` - Path to the text detection model file
84    /// * `text_recognition_model_path` - Path to the text recognition model file
85    /// * `character_dict_path` - Path to the character dictionary file
86    ///
87    /// # Returns
88    ///
89    /// A new OAROCRConfig instance with default values
90    pub fn new(
91        text_detection_model_path: impl Into<PathBuf>,
92        text_recognition_model_path: impl Into<PathBuf>,
93        character_dict_path: impl Into<PathBuf>,
94    ) -> Self {
95        let mut detection_config = TextDetPredictorConfig::new();
96        detection_config.common.model_path = Some(text_detection_model_path.into());
97        detection_config.common.batch_size = Some(1);
98        detection_config.limit_side_len = Some(736);
99        detection_config.limit_type = Some(LimitType::Max);
100
101        let mut recognition_config = TextRecPredictorConfig::new();
102        recognition_config.common.model_path = Some(text_recognition_model_path.into());
103        recognition_config.common.batch_size = Some(1);
104
105        Self {
106            detection: detection_config,
107            recognition: recognition_config,
108            orientation: None,
109            rectification: None,
110            text_line_orientation: None,
111            orientation_stage: None,
112            text_line_orientation_stage: None,
113            character_dict_path: character_dict_path.into(),
114            use_doc_orientation_classify: false,
115            use_doc_unwarping: false,
116            use_textline_orientation: false,
117            aspect_ratio_bucketing: None,
118            dynamic_batching: None,
119            parallel_policy: ParallelPolicy::default(),
120        }
121    }
122
123    /// Get the effective parallel policy
124    pub fn effective_parallel_policy(&self) -> ParallelPolicy {
125        self.parallel_policy.clone()
126    }
127
128    /// Get the maximum number of threads for parallel processing
129    pub fn max_threads(&self) -> Option<usize> {
130        self.effective_parallel_policy().max_threads
131    }
132
133    /// Get the image processing threshold
134    pub fn image_threshold(&self) -> usize {
135        self.effective_parallel_policy().image_threshold
136    }
137
138    /// Get the text box processing threshold
139    pub fn text_box_threshold(&self) -> usize {
140        self.effective_parallel_policy().text_box_threshold
141    }
142
143    /// Get the batch processing threshold
144    pub fn batch_threshold(&self) -> usize {
145        self.effective_parallel_policy().batch_threshold
146    }
147
148    /// Get the utility operations threshold
149    pub fn utility_threshold(&self) -> usize {
150        self.effective_parallel_policy().utility_threshold
151    }
152
153    /// Get the postprocessing pixel threshold
154    pub fn postprocess_pixel_threshold(&self) -> usize {
155        self.effective_parallel_policy().postprocess_pixel_threshold
156    }
157
158    /// Get the ONNX threading configuration
159    pub fn onnx_threading(&self) -> OnnxThreadingConfig {
160        self.effective_parallel_policy().onnx_threading
161    }
162}
163
164/// Implementation of Default for OAROCRConfig.
165///
166/// This provides a default configuration that can be used for testing.
167/// Note: This default configuration will not work for actual OCR processing
168/// as it doesn't specify valid model paths.
169impl Default for OAROCRConfig {
170    fn default() -> Self {
171        Self::new(
172            "default_detection_model.onnx",
173            "default_recognition_model.onnx",
174            "default_char_dict.txt",
175        )
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `with_*` builder method must store the value it was given.
    #[test]
    fn test_parallel_policy_builder() {
        let onnx_config = OnnxThreadingConfig {
            intra_threads: Some(4),
            inter_threads: Some(2),
            parallel_execution: Some(true),
        };

        let policy = ParallelPolicy::new()
            .with_max_threads(Some(8))
            .with_image_threshold(2)
            .with_text_box_threshold(5)
            .with_batch_threshold(20)
            .with_utility_threshold(8)
            .with_postprocess_pixel_threshold(16000)
            .with_onnx_threading(onnx_config.clone());

        assert_eq!(policy.max_threads, Some(8));
        assert_eq!(policy.image_threshold, 2);
        assert_eq!(policy.text_box_threshold, 5);
        assert_eq!(policy.batch_threshold, 20);
        assert_eq!(policy.utility_threshold, 8);
        assert_eq!(policy.postprocess_pixel_threshold, 16000);
        assert_eq!(policy.onnx_threading.intra_threads, Some(4));
        assert_eq!(policy.onnx_threading.inter_threads, Some(2));
        assert_eq!(policy.onnx_threading.parallel_execution, Some(true));
    }

    /// A policy must survive a JSON round trip with every field intact.
    #[test]
    fn test_parallel_policy_serialization() {
        let policy = ParallelPolicy::new()
            .with_max_threads(Some(4))
            .with_image_threshold(3);

        let serialized = serde_json::to_string(&policy).unwrap();
        let deserialized: ParallelPolicy = serde_json::from_str(&serialized).unwrap();

        assert_eq!(policy.max_threads, deserialized.max_threads);
        assert_eq!(policy.image_threshold, deserialized.image_threshold);
        assert_eq!(policy.text_box_threshold, deserialized.text_box_threshold);
        assert_eq!(policy.batch_threshold, deserialized.batch_threshold);
        assert_eq!(policy.utility_threshold, deserialized.utility_threshold);
        // The remaining fields are part of the round trip too.
        assert_eq!(
            policy.postprocess_pixel_threshold,
            deserialized.postprocess_pixel_threshold
        );
        assert_eq!(
            policy.onnx_threading.intra_threads,
            deserialized.onnx_threading.intra_threads
        );
        assert_eq!(
            policy.onnx_threading.inter_threads,
            deserialized.onnx_threading.inter_threads
        );
        assert_eq!(
            policy.onnx_threading.parallel_execution,
            deserialized.onnx_threading.parallel_execution
        );
    }

    /// `effective_parallel_policy` must reflect the policy stored on the config.
    #[test]
    fn test_oarocr_config_effective_parallel_policy() {
        let mut config = OAROCRConfig::default();

        // Test with default policy
        let policy = config.effective_parallel_policy();
        assert_eq!(policy.max_threads, None);
        assert_eq!(policy.image_threshold, 1);
        assert_eq!(policy.text_box_threshold, 1);

        // Test with custom parallel policy
        config.parallel_policy = ParallelPolicy::new()
            .with_max_threads(Some(6))
            .with_image_threshold(3);

        let policy = config.effective_parallel_policy();
        assert_eq!(policy.max_threads, Some(6));
        assert_eq!(policy.image_threshold, 3);
        assert_eq!(policy.text_box_threshold, 1);
    }

    /// The convenience accessors must agree with the stored parallel policy.
    #[test]
    fn test_oarocr_config_parallel_policy() {
        let config = OAROCRConfig {
            parallel_policy: ParallelPolicy::new()
                .with_max_threads(Some(4))
                .with_image_threshold(2),
            ..Default::default()
        };

        let policy = config.effective_parallel_policy();
        assert_eq!(policy.max_threads, Some(4));
        assert_eq!(policy.image_threshold, 2);
        assert_eq!(policy.text_box_threshold, 1); // Default

        // Test convenience methods
        assert_eq!(config.max_threads(), Some(4));
        assert_eq!(config.image_threshold(), 2);
        assert_eq!(config.text_box_threshold(), 1);
        assert_eq!(config.batch_threshold(), 10); // Default
        assert_eq!(config.utility_threshold(), 4); // Default

        // Accessors without a pinned default are checked against the stored
        // field, so this does not depend on the default values.
        assert_eq!(
            config.postprocess_pixel_threshold(),
            config.parallel_policy.postprocess_pixel_threshold
        );
        let onnx = config.onnx_threading();
        assert_eq!(
            onnx.intra_threads,
            config.parallel_policy.onnx_threading.intra_threads
        );
        assert_eq!(
            onnx.inter_threads,
            config.parallel_policy.onnx_threading.inter_threads
        );
        assert_eq!(
            onnx.parallel_execution,
            config.parallel_policy.onnx_threading.parallel_execution
        );
    }
}