oar_ocr/pipeline/oarocr/
config.rs1use crate::core::DynamicBatchConfig;
4use crate::pipeline::stages::{OrientationConfig, TextLineOrientationConfig};
5use crate::predictor::{
6 DocOrientationClassifierConfig, DoctrRectifierPredictorConfig, TextDetPredictorConfig,
7 TextLineClasPredictorConfig, TextRecPredictorConfig,
8};
9use crate::processors::{AspectRatioBucketingConfig, LimitType};
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12
13pub use crate::core::config::{OnnxThreadingConfig, ParallelPolicy};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct OAROCRConfig {
17 #[serde(default)]
19 pub detection: TextDetPredictorConfig,
20
21 #[serde(default)]
23 pub recognition: TextRecPredictorConfig,
24
25 #[serde(default)]
27 pub orientation: Option<DocOrientationClassifierConfig>,
28
29 #[serde(default)]
31 pub rectification: Option<DoctrRectifierPredictorConfig>,
32
33 #[serde(default)]
35 pub text_line_orientation: Option<TextLineClasPredictorConfig>,
36
37 #[serde(default)]
39 pub orientation_stage: Option<OrientationConfig>,
40
41 #[serde(default)]
43 pub text_line_orientation_stage: Option<TextLineOrientationConfig>,
44
45 pub character_dict_path: PathBuf,
47
48 #[serde(default)]
50 pub use_doc_orientation_classify: bool,
51
52 #[serde(default)]
54 pub use_doc_unwarping: bool,
55
56 #[serde(default)]
58 pub use_textline_orientation: bool,
59
60 #[serde(default)]
63 pub aspect_ratio_bucketing: Option<AspectRatioBucketingConfig>,
64
65 #[serde(default)]
68 pub dynamic_batching: Option<DynamicBatchConfig>,
69
70 #[serde(default)]
72 pub parallel_policy: ParallelPolicy,
73}
74
75impl OAROCRConfig {
76 pub fn new(
91 text_detection_model_path: impl Into<PathBuf>,
92 text_recognition_model_path: impl Into<PathBuf>,
93 character_dict_path: impl Into<PathBuf>,
94 ) -> Self {
95 let mut detection_config = TextDetPredictorConfig::new();
96 detection_config.common.model_path = Some(text_detection_model_path.into());
97 detection_config.common.batch_size = Some(1);
98 detection_config.limit_side_len = Some(736);
99 detection_config.limit_type = Some(LimitType::Max);
100
101 let mut recognition_config = TextRecPredictorConfig::new();
102 recognition_config.common.model_path = Some(text_recognition_model_path.into());
103 recognition_config.common.batch_size = Some(1);
104
105 Self {
106 detection: detection_config,
107 recognition: recognition_config,
108 orientation: None,
109 rectification: None,
110 text_line_orientation: None,
111 orientation_stage: None,
112 text_line_orientation_stage: None,
113 character_dict_path: character_dict_path.into(),
114 use_doc_orientation_classify: false,
115 use_doc_unwarping: false,
116 use_textline_orientation: false,
117 aspect_ratio_bucketing: None,
118 dynamic_batching: None,
119 parallel_policy: ParallelPolicy::default(),
120 }
121 }
122
123 pub fn effective_parallel_policy(&self) -> ParallelPolicy {
125 self.parallel_policy.clone()
126 }
127
128 pub fn max_threads(&self) -> Option<usize> {
130 self.effective_parallel_policy().max_threads
131 }
132
133 pub fn image_threshold(&self) -> usize {
135 self.effective_parallel_policy().image_threshold
136 }
137
138 pub fn text_box_threshold(&self) -> usize {
140 self.effective_parallel_policy().text_box_threshold
141 }
142
143 pub fn batch_threshold(&self) -> usize {
145 self.effective_parallel_policy().batch_threshold
146 }
147
148 pub fn utility_threshold(&self) -> usize {
150 self.effective_parallel_policy().utility_threshold
151 }
152
153 pub fn postprocess_pixel_threshold(&self) -> usize {
155 self.effective_parallel_policy().postprocess_pixel_threshold
156 }
157
158 pub fn onnx_threading(&self) -> OnnxThreadingConfig {
160 self.effective_parallel_policy().onnx_threading
161 }
162}
163
164impl Default for OAROCRConfig {
170 fn default() -> Self {
171 Self::new(
172 "default_detection_model.onnx",
173 "default_recognition_model.onnx",
174 "default_char_dict.txt",
175 )
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `with_*` builder call must end up in the matching policy field.
    #[test]
    fn test_parallel_policy_builder() {
        let threading = OnnxThreadingConfig {
            intra_threads: Some(4),
            inter_threads: Some(2),
            parallel_execution: Some(true),
        };

        let built = ParallelPolicy::new()
            .with_max_threads(Some(8))
            .with_image_threshold(2)
            .with_text_box_threshold(5)
            .with_batch_threshold(20)
            .with_utility_threshold(8)
            .with_postprocess_pixel_threshold(16000)
            .with_onnx_threading(threading);

        // Scalar thresholds.
        assert_eq!(built.max_threads, Some(8));
        assert_eq!(built.image_threshold, 2);
        assert_eq!(built.text_box_threshold, 5);
        assert_eq!(built.batch_threshold, 20);
        assert_eq!(built.utility_threshold, 8);
        assert_eq!(built.postprocess_pixel_threshold, 16000);

        // Nested ONNX threading settings.
        assert_eq!(built.onnx_threading.intra_threads, Some(4));
        assert_eq!(built.onnx_threading.inter_threads, Some(2));
        assert_eq!(built.onnx_threading.parallel_execution, Some(true));
    }

    /// A policy must survive a JSON round trip with its thresholds intact.
    #[test]
    fn test_parallel_policy_serialization() {
        let before = ParallelPolicy::new()
            .with_max_threads(Some(4))
            .with_image_threshold(3);

        let json = serde_json::to_string(&before).unwrap();
        let after: ParallelPolicy = serde_json::from_str(&json).unwrap();

        assert_eq!(before.max_threads, after.max_threads);
        assert_eq!(before.image_threshold, after.image_threshold);
        assert_eq!(before.text_box_threshold, after.text_box_threshold);
        assert_eq!(before.batch_threshold, after.batch_threshold);
        assert_eq!(before.utility_threshold, after.utility_threshold);
    }

    /// `effective_parallel_policy` must reflect the defaults first and a
    /// replaced policy afterwards.
    #[test]
    fn test_oarocr_config_effective_parallel_policy() {
        let mut cfg = OAROCRConfig::default();

        // Untouched configuration: default policy values.
        let defaults = cfg.effective_parallel_policy();
        assert_eq!(defaults.max_threads, None);
        assert_eq!(defaults.image_threshold, 1);
        assert_eq!(defaults.text_box_threshold, 1);

        cfg.parallel_policy = ParallelPolicy::new()
            .with_max_threads(Some(6))
            .with_image_threshold(3);

        // Replaced policy: overridden fields change, the rest keep their defaults.
        let overridden = cfg.effective_parallel_policy();
        assert_eq!(overridden.max_threads, Some(6));
        assert_eq!(overridden.image_threshold, 3);
        assert_eq!(overridden.text_box_threshold, 1);
    }

    /// The convenience accessors must agree with the policy they read from.
    #[test]
    fn test_oarocr_config_parallel_policy() {
        let custom_policy = ParallelPolicy::new()
            .with_max_threads(Some(4))
            .with_image_threshold(2);
        let cfg = OAROCRConfig {
            parallel_policy: custom_policy,
            ..Default::default()
        };

        // Through the cloned policy.
        let snapshot = cfg.effective_parallel_policy();
        assert_eq!(snapshot.max_threads, Some(4));
        assert_eq!(snapshot.image_threshold, 2);
        assert_eq!(snapshot.text_box_threshold, 1);

        // Through the per-field accessors.
        assert_eq!(cfg.max_threads(), Some(4));
        assert_eq!(cfg.image_threshold(), 2);
        assert_eq!(cfg.text_box_threshold(), 1);
        assert_eq!(cfg.batch_threshold(), 10);
        assert_eq!(cfg.utility_threshold(), 4);
    }
}