oar_ocr/pipeline/oarocr/builder.rs
1//! Builder pattern implementation for the OAROCR pipeline.
2
3use crate::core::{DynamicBatchConfig, ShapeCompatibilityStrategy};
4use crate::pipeline::oarocr::config::OAROCRConfig;
5use crate::predictor::{
6 DocOrientationClassifierConfig, DoctrRectifierPredictorConfig, TextLineClasPredictorConfig,
7};
8use crate::processors::{AspectRatioBucketingConfig, LimitType};
9use crate::{impl_complete_builder, with_nested};
10use std::path::PathBuf;
11use tracing::warn;
12
13/// Builder for creating OAROCR instances.
14///
15/// This struct provides a fluent API for configuring and building
16/// OAROCR pipeline instances with various options.
17#[derive(Debug)]
18pub struct OAROCRBuilder {
19 config: OAROCRConfig,
20}
21
22impl OAROCRBuilder {
23 /// Validates and clamps a threshold value to the range [0.0, 1.0].
24 ///
25 /// # Arguments
26 ///
27 /// * `threshold` - The threshold value to validate
28 /// * `param_name` - The name of the parameter for logging purposes
29 ///
30 /// # Returns
31 ///
32 /// The validated and potentially clamped threshold value
33 fn validate_threshold(threshold: f32, param_name: &str) -> f32 {
34 if (0.0..=1.0).contains(&threshold) {
35 threshold
36 } else {
37 warn!("{param_name} out of range [{threshold}], clamping to [0.0, 1.0]");
38 threshold.clamp(0.0, 1.0)
39 }
40 }
41
42 /// Validates and ensures a size value is at least 1.
43 ///
44 /// # Arguments
45 ///
46 /// * `size` - The size value to validate
47 /// * `param_name` - The name of the parameter for logging purposes
48 ///
49 /// # Returns
50 ///
51 /// The validated size value (minimum 1)
52 fn validate_min_size_usize(size: usize, param_name: &str) -> usize {
53 if size >= 1 {
54 size
55 } else {
56 warn!("{param_name} must be >= 1, got {size}; using 1");
57 1
58 }
59 }
60
61 /// Validates and ensures a size value is at least 1.
62 ///
63 /// # Arguments
64 ///
65 /// * `size` - The size value to validate
66 /// * `param_name` - The name of the parameter for logging purposes
67 ///
68 /// # Returns
69 ///
70 /// The validated size value (minimum 1)
71 fn validate_min_size_u32(size: u32, param_name: &str) -> u32 {
72 if size >= 1 {
73 size
74 } else {
75 warn!("{param_name} must be >= 1, got {size}; using 1");
76 1
77 }
78 }
79
80 /// Validates and ensures a dimension value is greater than 0.
81 ///
82 /// # Arguments
83 ///
84 /// * `dimension` - The dimension value to validate
85 /// * `param_name` - The name of the parameter for logging purposes
86 ///
87 /// # Returns
88 ///
89 /// The validated dimension value (minimum 1)
90 fn validate_dimension(dimension: u32, param_name: &str) -> u32 {
91 if dimension > 0 {
92 dimension
93 } else {
94 warn!("{param_name} {} <= 0; using 1", dimension);
95 1u32
96 }
97 }
98
99 /// Validates and ensures a positive float value.
100 ///
101 /// # Arguments
102 ///
103 /// * `value` - The value to validate
104 /// * `param_name` - The name of the parameter for logging purposes
105 /// * `default` - The default value to use if validation fails
106 ///
107 /// # Returns
108 ///
109 /// The validated value or the default
110 fn validate_positive_f32(value: f32, param_name: &str, default: f32) -> f32 {
111 if value > 0.0 {
112 value
113 } else {
114 warn!("{param_name} must be > 0.0, got {value}; using {default}");
115 default
116 }
117 }
118
119 /// Validates and ensures a non-negative float value.
120 ///
121 /// # Arguments
122 ///
123 /// * `value` - The value to validate
124 /// * `param_name` - The name of the parameter for logging purposes
125 ///
126 /// # Returns
127 ///
128 /// The validated value (minimum 0.0)
129 fn validate_non_negative_f32(value: f32, param_name: &str) -> f32 {
130 if value >= 0.0 {
131 value
132 } else {
133 warn!("{param_name} must be >= 0.0, got {value}; using 0.0");
134 0.0
135 }
136 }
137
138 /// Creates a new OAROCRBuilder with the required parameters.
139 ///
140 /// # Arguments
141 ///
142 /// * `text_detection_model_path` - Path to the text detection model file
143 /// * `text_recognition_model_path` - Path to the text recognition model file
144 /// * `text_rec_character_dict_path` - Path to the character dictionary file
145 ///
146 /// # Returns
147 ///
148 /// A new OAROCRBuilder instance
149 pub fn new(
150 text_detection_model_path: String,
151 text_recognition_model_path: String,
152 text_rec_character_dict_path: String,
153 ) -> Self {
154 Self {
155 config: OAROCRConfig::new(
156 text_detection_model_path,
157 text_recognition_model_path,
158 text_rec_character_dict_path,
159 ),
160 }
161 }
162
163 /// Creates a new OAROCRBuilder from an existing configuration.
164 ///
165 /// # Arguments
166 ///
167 /// * `config` - The OAROCRConfig to use
168 ///
169 /// # Returns
170 ///
171 /// A new OAROCRBuilder instance
172 pub fn from_config(config: OAROCRConfig) -> Self {
173 Self { config }
174 }
175
176 /// Sets the document orientation classification model name.
177 ///
178 /// # Arguments
179 ///
180 /// * `name` - The model name
181 ///
182 /// # Returns
183 ///
184 /// The updated builder instance
185 pub fn doc_orientation_classify_model_name(mut self, name: String) -> Self {
186 with_nested!(self.config.orientation, DocOrientationClassifierConfig, config => {
187 config.common.model_name = Some(name);
188 });
189 self
190 }
191
192 /// Sets the document orientation classification model path.
193 ///
194 /// # Arguments
195 ///
196 /// * `path` - The path to the model file
197 ///
198 /// # Returns
199 ///
200 /// The updated builder instance
201 pub fn doc_orientation_classify_model_path(mut self, path: impl Into<PathBuf>) -> Self {
202 with_nested!(self.config.orientation, DocOrientationClassifierConfig, config => {
203 config.common.model_path = Some(path.into());
204 });
205 self
206 }
207
208 /// Sets the document orientation confidence threshold.
209 ///
210 /// Specifies the minimum confidence score required for orientation predictions.
211 /// If the confidence is below this threshold, the orientation may be treated
212 /// as uncertain and fall back to default behavior.
213 ///
214 /// # Arguments
215 ///
216 /// * `threshold` - Minimum confidence threshold (0.0 to 1.0)
217 ///
218 /// # Returns
219 /// The updated builder instance
220 /// Sets the confidence threshold for document orientation classification.
221 ///
222 /// This threshold determines the minimum confidence required for orientation predictions.
223 /// If the confidence is below this threshold, the orientation may be treated
224 /// as uncertain and fall back to default behavior.
225 ///
226 /// # Arguments
227 ///
228 /// * `threshold` - Minimum confidence threshold (0.0 to 1.0)
229 ///
230 /// # Returns
231 ///
232 /// The updated builder instance
233 pub fn doc_orientation_threshold(mut self, threshold: f32) -> Self {
234 let t = Self::validate_threshold(threshold, "doc_orientation_threshold");
235 if self.config.orientation_stage.is_none() {
236 self.config.orientation_stage =
237 Some(crate::pipeline::stages::OrientationConfig::default());
238 }
239 if let Some(ref mut config) = self.config.orientation_stage {
240 config.confidence_threshold = Some(t);
241 }
242 self
243 }
244
245 /// Sets the document unwarping model name.
246 ///
247 /// # Arguments
248 ///
249 /// * `name` - The model name
250 ///
251 /// # Returns
252 ///
253 /// The updated builder instance
254 pub fn doc_unwarping_model_name(mut self, name: String) -> Self {
255 with_nested!(self.config.rectification, DoctrRectifierPredictorConfig, config => {
256 config.common.model_name = Some(name);
257 });
258 self
259 }
260
261 /// Sets the document unwarping model path.
262 ///
263 /// # Arguments
264 ///
265 /// * `path` - The path to the model file
266 ///
267 /// # Returns
268 ///
269 /// The updated builder instance
270 pub fn doc_unwarping_model_path(mut self, path: impl Into<PathBuf>) -> Self {
271 with_nested!(self.config.rectification, DoctrRectifierPredictorConfig, config => {
272 config.common.model_path = Some(path.into());
273 });
274 self
275 }
276
277 /// Sets the text detection model name.
278 ///
279 /// # Arguments
280 ///
281 /// * `name` - The model name
282 ///
283 /// # Returns
284 ///
285 /// The updated builder instance
286 pub fn text_detection_model_name(mut self, name: String) -> Self {
287 self.config.detection.common.model_name = Some(name);
288 self
289 }
290
291 /// Sets the text detection model path.
292 ///
293 /// # Arguments
294 ///
295 /// * `path` - The path to the model file
296 ///
297 /// # Returns
298 ///
299 /// The updated builder instance
300 pub fn text_detection_model_path(mut self, path: impl Into<PathBuf>) -> Self {
301 self.config.detection.common.model_path = Some(path.into());
302 self
303 }
304
305 /// Sets the text detection batch size.
306 ///
307 /// # Arguments
308 ///
309 /// * `batch_size` - The batch size for inference
310 ///
311 /// # Returns
312 ///
313 /// The updated builder instance
314 pub fn text_detection_batch_size(mut self, batch_size: usize) -> Self {
315 let bs = Self::validate_min_size_usize(batch_size, "text_detection_batch_size");
316 self.config.detection.common.batch_size = Some(bs);
317 self
318 }
319
320 /// Sets the text recognition model name.
321 ///
322 /// # Arguments
323 ///
324 /// * `name` - The model name
325 ///
326 /// # Returns
327 ///
328 /// The updated builder instance
329 pub fn text_recognition_model_name(mut self, name: String) -> Self {
330 self.config.recognition.common.model_name = Some(name);
331 self
332 }
333
334 /// Sets the text recognition model path.
335 ///
336 /// # Arguments
337 ///
338 /// * `path` - The path to the model file
339 ///
340 /// # Returns
341 ///
342 /// The updated builder instance
343 pub fn text_recognition_model_path(mut self, path: impl Into<PathBuf>) -> Self {
344 self.config.recognition.common.model_path = Some(path.into());
345 self
346 }
347
348 /// Sets the text recognition batch size.
349 ///
350 /// # Arguments
351 ///
352 /// * `batch_size` - The batch size for inference
353 ///
354 /// # Returns
355 ///
356 /// The updated builder instance
357 pub fn text_recognition_batch_size(mut self, batch_size: usize) -> Self {
358 let bs = Self::validate_min_size_usize(batch_size, "text_recognition_batch_size");
359 self.config.recognition.common.batch_size = Some(bs);
360 self
361 }
362
363 /// Sets the text line orientation classification model name.
364 ///
365 /// # Arguments
366 ///
367 /// * `name` - The model name
368 ///
369 /// # Returns
370 ///
371 /// The updated builder instance
372 pub fn textline_orientation_classify_model_name(mut self, name: String) -> Self {
373 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, config => {
374 config.common.model_name = Some(name);
375 });
376 self
377 }
378
379 /// Sets the text line orientation classification model path.
380 ///
381 /// # Arguments
382 ///
383 /// * `path` - The path to the model file
384 ///
385 /// # Returns
386 ///
387 /// The updated builder instance
388 pub fn textline_orientation_classify_model_path(mut self, path: impl Into<PathBuf>) -> Self {
389 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, config => {
390 config.common.model_path = Some(path.into());
391 });
392 self
393 }
394
395 /// Sets the text line orientation classification batch size.
396 ///
397 /// # Arguments
398 ///
399 /// * `batch_size` - The batch size for inference
400 ///
401 /// # Returns
402 ///
403 /// The updated builder instance
404 pub fn textline_orientation_classify_batch_size(mut self, batch_size: usize) -> Self {
405 let bs =
406 Self::validate_min_size_usize(batch_size, "textline_orientation_classify_batch_size");
407 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, config => {
408 config.common.batch_size = Some(bs);
409 });
410 self
411 }
412
413 /// Sets the text line orientation classification input shape.
414 ///
415 /// # Arguments
416 ///
417 /// * `shape` - The input shape as (width, height)
418 ///
419 /// # Returns
420 ///
421 /// The updated builder instance
422 ///
423 /// Sets the text line orientation classifier input shape.
424 ///
425 /// # Arguments
426 ///
427 /// * `shape` - The input shape as (width, height)
428 ///
429 /// # Returns
430 ///
431 /// The updated builder instance
432 pub fn textline_orientation_input_shape(mut self, shape: (u32, u32)) -> Self {
433 let w = Self::validate_dimension(shape.0, "textline_orientation_input_shape width");
434 let h = Self::validate_dimension(shape.1, "textline_orientation_input_shape height");
435 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, config => {
436 config.input_shape = Some((w, h));
437 });
438 self
439 }
440
441 /// Sets the text line orientation confidence threshold.
442 ///
443 /// Specifies the minimum confidence score required for text line orientation predictions.
444 /// If the confidence is below this threshold, the orientation may be treated
445 /// as uncertain and fall back to default behavior.
446 ///
447 /// # Arguments
448 ///
449 /// * `threshold` - Minimum confidence threshold (0.0 to 1.0)
450 ///
451 /// # Returns
452 ///
453 /// The updated builder instance
454 ///
455 /// Sets the text line orientation confidence threshold.
456 ///
457 /// Specifies the minimum confidence score required for text line orientation predictions.
458 /// If the confidence is below this threshold, the orientation may be treated
459 /// as uncertain and fall back to default behavior.
460 ///
461 /// # Arguments
462 ///
463 /// * `threshold` - Minimum confidence threshold (0.0 to 1.0)
464 ///
465 /// # Returns
466 ///
467 /// The updated builder instance
468 pub fn textline_orientation_threshold(mut self, threshold: f32) -> Self {
469 let t = Self::validate_threshold(threshold, "textline_orientation_threshold");
470 if self.config.text_line_orientation_stage.is_none() {
471 self.config.text_line_orientation_stage =
472 Some(crate::pipeline::stages::TextLineOrientationConfig::default());
473 }
474 if let Some(ref mut config) = self.config.text_line_orientation_stage {
475 config.confidence_threshold = Some(t);
476 }
477 self
478 }
479
480 /// Sets whether to use document orientation classification.
481 ///
482 /// # Arguments
483 ///
484 /// * `use_it` - Whether to use document orientation classification
485 ///
486 /// # Returns
487 ///
488 /// The updated builder instance
489 pub fn use_doc_orientation_classify(mut self, use_it: bool) -> Self {
490 self.config.use_doc_orientation_classify = use_it;
491 self
492 }
493
494 /// Sets whether to use document unwarping.
495 ///
496 /// # Arguments
497 ///
498 /// * `use_it` - Whether to use document unwarping
499 ///
500 /// # Returns
501 ///
502 /// The updated builder instance
503 pub fn use_doc_unwarping(mut self, use_it: bool) -> Self {
504 self.config.use_doc_unwarping = use_it;
505 self
506 }
507
508 /// Sets whether to use text line orientation classification.
509 ///
510 /// # Arguments
511 ///
512 /// * `use_it` - Whether to use text line orientation classification
513 ///
514 /// # Returns
515 ///
516 /// The updated builder instance
517 pub fn use_textline_orientation(mut self, use_it: bool) -> Self {
518 self.config.use_textline_orientation = use_it;
519 self
520 }
521
522 /// Sets the parallel processing policy for the pipeline.
523 ///
524 /// # Arguments
525 ///
526 /// * `policy` - The parallel processing policy configuration
527 ///
528 /// # Returns
529 ///
530 /// The updated builder instance
531 pub fn parallel_policy(mut self, policy: super::config::ParallelPolicy) -> Self {
532 self.config.parallel_policy = policy;
533 self
534 }
535
536 /// Sets the ONNX Runtime session configuration for text detection.
537 ///
538 /// # Arguments
539 ///
540 /// * `config` - The ONNX Runtime session configuration
541 ///
542 /// # Returns
543 ///
544 /// The updated builder instance
545 pub fn text_detection_ort_session(
546 mut self,
547 config: crate::core::config::onnx::OrtSessionConfig,
548 ) -> Self {
549 self.config.detection.common.ort_session = Some(config);
550 self
551 }
552
553 /// Sets the ONNX Runtime session configuration for text recognition.
554 ///
555 /// # Arguments
556 ///
557 /// * `config` - The ONNX Runtime session configuration
558 ///
559 /// # Returns
560 ///
561 /// The updated builder instance
562 pub fn text_recognition_ort_session(
563 mut self,
564 config: crate::core::config::onnx::OrtSessionConfig,
565 ) -> Self {
566 self.config.recognition.common.ort_session = Some(config);
567 self
568 }
569
570 /// Sets the ONNX Runtime session configuration for text line orientation classification.
571 ///
572 /// # Arguments
573 ///
574 /// * `config` - The ONNX Runtime session configuration
575 ///
576 /// # Returns
577 ///
578 /// The updated builder instance
579 pub fn textline_orientation_ort_session(
580 mut self,
581 config: crate::core::config::onnx::OrtSessionConfig,
582 ) -> Self {
583 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, tlo_config => {
584 tlo_config.common.ort_session = Some(config);
585 });
586 self
587 }
588
589 /// Sets the ONNX Runtime session configuration for all components.
590 ///
591 /// This is a convenience method that applies the same ONNX session configuration
592 /// to text detection, text recognition, and text line orientation classification.
593 ///
594 /// # Arguments
595 ///
596 /// * `config` - The ONNX Runtime session configuration
597 ///
598 /// # Returns
599 ///
600 /// The updated builder instance
601 pub fn global_ort_session(
602 mut self,
603 config: crate::core::config::onnx::OrtSessionConfig,
604 ) -> Self {
605 // Apply to text detection
606 self.config.detection.common.ort_session = Some(config.clone());
607
608 // Apply to text recognition
609 self.config.recognition.common.ort_session = Some(config.clone());
610
611 // Apply to text line orientation (if configured)
612 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, tlo_config => {
613 tlo_config.common.ort_session = Some(config.clone());
614 });
615
616 self
617 }
618
619 /// Convenience method to enable CUDA execution with default settings.
620 ///
621 /// This configures CUDA execution provider with sensible defaults:
622 /// - Uses GPU device 0
623 /// - Falls back to CPU if CUDA fails
624 /// - Uses default memory and performance settings
625 ///
626 /// # Returns
627 ///
628 /// The updated builder instance
629 ///
630 /// # Example
631 ///
632 /// ```rust,no_run
633 /// use oar_ocr::pipeline::OAROCRBuilder;
634 ///
635 /// let builder = OAROCRBuilder::new(
636 /// "detection.onnx".to_string(),
637 /// "recognition.onnx".to_string(),
638 /// "dict.txt".to_string()
639 /// )
640 /// .with_cuda(); // Simple CUDA setup
641 /// ```
642 #[cfg(feature = "cuda")]
643 pub fn with_cuda(self) -> Self {
644 self.with_cuda_device(0)
645 }
646
647 /// Convenience method to enable CUDA execution on a specific GPU device.
648 ///
649 /// This configures CUDA execution provider with:
650 /// - Specified GPU device ID
651 /// - Falls back to CPU if CUDA fails
652 /// - Uses default memory and performance settings
653 ///
654 /// # Arguments
655 ///
656 /// * `device_id` - The GPU device ID to use (0, 1, 2, etc.)
657 ///
658 /// # Returns
659 ///
660 /// The updated builder instance
661 ///
662 /// # Example
663 ///
664 /// ```rust,no_run
665 /// use oar_ocr::pipeline::OAROCRBuilder;
666 ///
667 /// let builder = OAROCRBuilder::new(
668 /// "detection.onnx".to_string(),
669 /// "recognition.onnx".to_string(),
670 /// "dict.txt".to_string()
671 /// )
672 /// .with_cuda_device(1); // Use GPU 1
673 /// ```
674 #[cfg(feature = "cuda")]
675 pub fn with_cuda_device(self, device_id: u32) -> Self {
676 use crate::core::config::{OrtExecutionProvider, OrtSessionConfig};
677
678 let ort_config = OrtSessionConfig::new().with_execution_providers(vec![
679 OrtExecutionProvider::CUDA {
680 device_id: Some(device_id as i32),
681 gpu_mem_limit: None,
682 arena_extend_strategy: None,
683 cudnn_conv_algo_search: None,
684 do_copy_in_default_stream: None,
685 cudnn_conv_use_max_workspace: None,
686 },
687 OrtExecutionProvider::CPU, // Fallback to CPU
688 ]);
689
690 self.global_ort_session(ort_config)
691 }
692
693 /// Convenience method to enable high-performance processing configuration.
694 ///
695 /// This applies optimizations for batch processing:
696 /// - Increases parallel processing thresholds
697 /// - Optimizes memory usage
698 /// - Configures efficient batching strategies
699 ///
700 /// # Returns
701 ///
702 /// The updated builder instance
703 ///
704 /// # Example
705 ///
706 /// ```rust,no_run
707 /// use oar_ocr::pipeline::OAROCRBuilder;
708 ///
709 /// let builder = OAROCRBuilder::new(
710 /// "detection.onnx".to_string(),
711 /// "recognition.onnx".to_string(),
712 /// "dict.txt".to_string()
713 /// )
714 /// .with_high_performance(); // Optimize for batch processing
715 /// ```
716 pub fn with_high_performance(self) -> Self {
717 let max_threads = std::thread::available_parallelism()
718 .map(|n| n.get())
719 .unwrap_or(4); // Fallback to 4 threads if detection fails
720
721 self.parallel_policy(
722 super::config::ParallelPolicy::new()
723 .with_max_threads(Some(max_threads))
724 .with_image_threshold(2)
725 .with_text_box_threshold(5)
726 .with_batch_threshold(3),
727 )
728 }
729
730 /// Convenience method for mobile/resource-constrained environments.
731 ///
732 /// This configures the pipeline for minimal resource usage:
733 /// - Reduces parallel processing to avoid overwhelming the system
734 /// - Uses conservative memory settings
735 /// - Prioritizes stability over speed
736 ///
737 /// # Returns
738 ///
739 /// The updated builder instance
740 ///
741 /// # Example
742 ///
743 /// ```rust,no_run
744 /// use oar_ocr::pipeline::OAROCRBuilder;
745 ///
746 /// let builder = OAROCRBuilder::new(
747 /// "detection.onnx".to_string(),
748 /// "recognition.onnx".to_string(),
749 /// "dict.txt".to_string()
750 /// )
751 /// .with_low_resource(); // Optimize for limited resources
752 /// ```
753 pub fn with_low_resource(self) -> Self {
754 self.parallel_policy(
755 super::config::ParallelPolicy::new()
756 .with_max_threads(Some(2)) // Limit threads
757 .with_image_threshold(10) // Higher thresholds for parallel processing
758 .with_text_box_threshold(20)
759 .with_batch_threshold(10),
760 )
761 }
762
763 /// Sets the session pool size for text detection.
764 ///
765 /// # Arguments
766 ///
767 /// * `size` - The session pool size (minimum 1)
768 ///
769 /// # Returns
770 ///
771 /// The updated builder instance
772 pub fn text_detection_session_pool_size(mut self, size: usize) -> Self {
773 let s = Self::validate_min_size_usize(size, "text_detection_session_pool_size");
774 self.config.detection.common.session_pool_size = Some(s);
775 self
776 }
777
778 /// Sets the session pool size for text recognition.
779 ///
780 /// # Arguments
781 ///
782 /// * `size` - The session pool size (minimum 1)
783 ///
784 /// # Returns
785 ///
786 /// The updated builder instance
787 pub fn text_recognition_session_pool_size(mut self, size: usize) -> Self {
788 let s = Self::validate_min_size_usize(size, "text_recognition_session_pool_size");
789 self.config.recognition.common.session_pool_size = Some(s);
790 self
791 }
792
793 /// Sets the session pool size for text line orientation classification.
794 ///
795 /// # Arguments
796 ///
797 /// * `size` - The session pool size (minimum 1)
798 ///
799 /// # Returns
800 ///
801 /// The updated builder instance
802 pub fn textline_orientation_session_pool_size(mut self, size: usize) -> Self {
803 let s = Self::validate_min_size_usize(size, "textline_orientation_session_pool_size");
804 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, tlo_config => {
805 tlo_config.common.session_pool_size = Some(s);
806 });
807 self
808 }
809
810 /// Sets the session pool size for all components.
811 ///
812 /// This is a convenience method that applies the same session pool size
813 /// to text detection, text recognition, and text line orientation classification.
814 ///
815 /// # Arguments
816 ///
817 /// * `size` - The session pool size (minimum 1)
818 ///
819 /// # Returns
820 ///
821 /// The updated builder instance
822 pub fn global_session_pool_size(mut self, size: usize) -> Self {
823 let s = Self::validate_min_size_usize(size, "global_session_pool_size");
824 // Apply to text detection
825 self.config.detection.common.session_pool_size = Some(s);
826
827 // Apply to text recognition
828 self.config.recognition.common.session_pool_size = Some(s);
829
830 // Apply to text line orientation (if configured)
831 with_nested!(self.config.text_line_orientation, TextLineClasPredictorConfig, tlo_config => {
832 tlo_config.common.session_pool_size = Some(s);
833 });
834
835 self
836 }
837
838 /// Sets the text detection limit side length.
839 ///
840 /// # Arguments
841 ///
842 /// * `limit` - The maximum side length for resizing
843 ///
844 /// # Returns
845 ///
846 /// The updated builder instance
847 pub fn text_det_limit_side_len(mut self, limit: u32) -> Self {
848 let l = Self::validate_min_size_u32(limit, "text_det_limit_side_len");
849 self.config.detection.limit_side_len = Some(l);
850 self
851 }
852
853 /// Sets the text detection limit type.
854 ///
855 /// # Arguments
856 ///
857 /// * `limit_type` - The type of limit for resizing
858 ///
859 /// # Returns
860 ///
861 /// The updated builder instance
862 pub fn text_det_limit_type(mut self, limit_type: LimitType) -> Self {
863 self.config.detection.limit_type = Some(limit_type);
864 self
865 }
866
867 /// Sets the text detection input shape.
868 ///
869 /// # Arguments
870 ///
871 /// * `shape` - The input shape as (channels, height, width)
872 ///
873 /// # Returns
874 ///
875 /// The updated builder instance
876 pub fn text_det_input_shape(mut self, shape: (u32, u32, u32)) -> Self {
877 let c = Self::validate_dimension(shape.0, "text_det_input_shape channels");
878 let h = Self::validate_dimension(shape.1, "text_det_input_shape height");
879 let w = Self::validate_dimension(shape.2, "text_det_input_shape width");
880 self.config.detection.input_shape = Some((c, h, w));
881 self
882 }
883
884 /// Sets the text detection maximum side limit.
885 ///
886 /// This controls the maximum allowed size for any side of an image during
887 /// text detection preprocessing. Images larger than this limit will be resized
888 /// to fit within the constraint while maintaining aspect ratio.
889 ///
890 /// # Arguments
891 ///
892 /// * `max_side_limit` - The maximum side limit for image processing (default: 4000)
893 ///
894 /// # Returns
895 ///
896 /// The updated builder instance
897 ///
898 /// # Example
899 ///
900 /// ```rust,no_run
901 /// use oar_ocr::pipeline::OAROCRBuilder;
902 ///
903 /// let builder = OAROCRBuilder::new(
904 /// "detection_model.onnx".to_string(),
905 /// "recognition_model.onnx".to_string(),
906 /// "char_dict.txt".to_string()
907 /// )
908 /// .text_det_max_side_limit(5000); // Allow larger images
909 /// ```
910 pub fn text_det_max_side_limit(mut self, max_side_limit: u32) -> Self {
911 let m = Self::validate_min_size_u32(max_side_limit, "text_det_max_side_limit");
912 self.config.detection.max_side_limit = Some(m);
913 self
914 }
915
916 /// Sets the text detection binarization threshold.
917 ///
918 /// This controls the threshold used for binarizing the detection output.
919 /// Lower values may detect more text but with more false positives.
920 ///
921 /// # Arguments
922 ///
923 /// * `thresh` - The binarization threshold (default: 0.3)
924 ///
925 /// # Returns
926 ///
927 /// The updated builder instance
928 ///
929 /// # Example
930 ///
931 /// ```rust,no_run
932 /// use oar_ocr::pipeline::OAROCRBuilder;
933 ///
934 /// let builder = OAROCRBuilder::new(
935 /// "detection_model.onnx".to_string(),
936 /// "recognition_model.onnx".to_string(),
937 /// "char_dict.txt".to_string()
938 /// )
939 /// .text_det_threshold(0.4); // Higher threshold for more precise detection
940 /// ```
941 ///
942 /// Sets the text detection binarization threshold.
943 ///
944 /// This controls the threshold used for binarizing the detection output.
945 /// Lower values may detect more text but with more false positives.
946 ///
947 /// # Arguments
948 ///
949 /// * `threshold` - The binarization threshold (default: 0.3)
950 ///
951 /// # Returns
952 ///
953 /// The updated builder instance
954 pub fn text_det_threshold(mut self, threshold: f32) -> Self {
955 let t = Self::validate_threshold(threshold, "text_det_threshold");
956 self.config.detection.thresh = Some(t);
957 self
958 }
959
960 /// Sets the text detection box score threshold.
961 ///
962 /// This controls the threshold for filtering text boxes based on their confidence scores.
963 /// Higher values will filter out more uncertain detections.
964 ///
965 /// # Arguments
966 ///
967 /// * `box_thresh` - The box score threshold (default: 0.6)
968 ///
969 /// # Returns
970 ///
971 /// The updated builder instance
972 ///
973 /// # Example
974 ///
975 /// ```rust,no_run
976 /// use oar_ocr::pipeline::OAROCRBuilder;
977 ///
978 /// let builder = OAROCRBuilder::new(
979 /// "detection_model.onnx".to_string(),
980 /// "recognition_model.onnx".to_string(),
981 /// "char_dict.txt".to_string()
982 /// )
983 /// .text_det_box_threshold(0.7); // Higher threshold for more confident boxes
984 /// ```
985 ///
986 /// Sets the text detection box score threshold.
987 ///
988 /// This controls the threshold for filtering text boxes based on their confidence scores.
989 /// Higher values will filter out more uncertain detections.
990 ///
991 /// # Arguments
992 ///
993 /// * `threshold` - The box score threshold (default: 0.6)
994 ///
995 /// # Returns
996 ///
997 /// The updated builder instance
998 pub fn text_det_box_threshold(mut self, threshold: f32) -> Self {
999 let t = Self::validate_threshold(threshold, "text_det_box_threshold");
1000 self.config.detection.box_thresh = Some(t);
1001 self
1002 }
1003
1004 /// Sets the text detection unclip ratio.
1005 ///
1006 /// This controls how much to expand detected text boxes. Higher values
1007 /// will expand boxes more, potentially capturing more complete text.
1008 ///
1009 /// # Arguments
1010 ///
1011 /// * `unclip_ratio` - The unclip ratio for expanding text boxes (default: 1.5)
1012 ///
1013 /// # Returns
1014 ///
1015 /// The updated builder instance
1016 ///
1017 /// # Example
1018 ///
1019 /// ```rust,no_run
1020 /// use oar_ocr::pipeline::OAROCRBuilder;
1021 ///
1022 /// let builder = OAROCRBuilder::new(
1023 /// "detection_model.onnx".to_string(),
1024 /// "recognition_model.onnx".to_string(),
1025 /// "char_dict.txt".to_string()
1026 /// )
1027 /// .text_det_unclip_ratio(2.0); // More expansion for better text capture
1028 /// ```
1029 pub fn text_det_unclip_ratio(mut self, unclip_ratio: f32) -> Self {
1030 let r = Self::validate_positive_f32(unclip_ratio, "text_det_unclip_ratio", 1.0);
1031 self.config.detection.unclip_ratio = Some(r);
1032 self
1033 }
1034
1035 /// Sets the text recognition score threshold.
1036 ///
1037 /// # Arguments
1038 ///
1039 /// * `thresh` - The minimum score threshold for recognition results
1040 ///
1041 /// # Returns
1042 ///
1043 /// The updated builder instance
1044 ///
1045 /// Sets the text recognition score threshold.
1046 ///
1047 /// Results with confidence scores below this threshold will be filtered out.
1048 ///
1049 /// # Arguments
1050 ///
1051 /// * `threshold` - The minimum score threshold for recognition results
1052 ///
1053 /// # Returns
1054 ///
1055 /// The updated builder instance
1056 pub fn text_rec_score_threshold(mut self, threshold: f32) -> Self {
1057 let t = Self::validate_threshold(threshold, "text_rec_score_threshold");
1058 self.config.recognition.score_thresh = Some(t);
1059 self
1060 }
1061
1062 /// Sets the text recognition model input shape.
1063 ///
1064 /// # Arguments
1065 ///
1066 /// * `shape` - The model input shape as (channels, height, width)
1067 ///
1068 /// # Returns
1069 ///
1070 /// The updated builder instance
1071 ///
1072 /// Sets the text recognition model input shape.
1073 ///
1074 /// # Arguments
1075 ///
1076 /// * `shape` - The model input shape as (channels, height, width)
1077 ///
1078 /// # Returns
1079 ///
1080 /// The updated builder instance
1081 pub fn text_rec_input_shape(mut self, shape: (u32, u32, u32)) -> Self {
1082 let c = Self::validate_dimension(shape.0, "text_rec_input_shape channels");
1083 let h = Self::validate_dimension(shape.1, "text_rec_input_shape height");
1084 let w = Self::validate_dimension(shape.2, "text_rec_input_shape width");
1085 self.config.recognition.model_input_shape = Some([c as usize, h as usize, w as usize]);
1086 self
1087 }
1088
1089 /// Sets the text recognition character dictionary path.
1090 ///
1091 /// # Arguments
1092 ///
1093 /// * `path` - The path to the character dictionary file
1094 ///
1095 /// # Returns
1096 ///
1097 /// The updated builder instance
1098 pub fn text_rec_character_dict_path(mut self, path: impl Into<PathBuf>) -> Self {
1099 self.config.character_dict_path = path.into();
1100 self
1101 }
1102
1103 /// Enables aspect ratio bucketing for text recognition with default configuration.
1104 ///
1105 /// This enables aspect ratio bucketing which groups images by aspect ratio ranges
1106 /// instead of exact dimensions, improving batch efficiency.
1107 ///
1108 /// # Returns
1109 ///
1110 /// The updated builder instance
1111 ///
1112 /// Sets a custom aspect ratio bucketing configuration.
1113 ///
1114 /// # Arguments
1115 ///
1116 /// * `config` - The aspect ratio bucketing configuration
1117 ///
1118 /// # Returns
1119 ///
1120 /// The updated builder instance
1121 pub fn aspect_ratio_bucketing_config(mut self, config: AspectRatioBucketingConfig) -> Self {
1122 self.config.aspect_ratio_bucketing = Some(config);
1123 self
1124 }
1125
1126 /// Disables aspect ratio bucketing (uses exact dimension grouping).
1127 ///
1128 /// # Returns
1129 ///
1130 /// The updated builder instance
1131 pub fn disable_aspect_ratio_bucketing(mut self) -> Self {
1132 self.config.aspect_ratio_bucketing = None;
1133 self
1134 }
1135
1136 /// Enables dynamic batching with default configuration.
1137 ///
1138 /// This enables cross-image batching for both detection and recognition,
1139 /// which can improve performance when processing multiple images with
1140 /// compatible shapes.
1141 ///
1142 /// # Returns
1143 ///
1144 /// The updated builder instance
1145 ///
1146 /// Sets a custom dynamic batching configuration.
1147 ///
1148 /// # Arguments
1149 ///
1150 /// * `config` - The dynamic batching configuration
1151 ///
1152 /// # Returns
1153 ///
1154 /// The updated builder instance
1155 pub fn dynamic_batching_config(mut self, config: DynamicBatchConfig) -> Self {
1156 self.config.dynamic_batching = Some(config);
1157 self
1158 }
1159
1160 /// Disables dynamic batching (processes images individually).
1161 ///
1162 /// # Returns
1163 ///
1164 /// The updated builder instance
1165 pub fn disable_dynamic_batching(mut self) -> Self {
1166 self.config.dynamic_batching = None;
1167 self
1168 }
1169
1170 /// Sets the maximum batch size for detection.
1171 ///
1172 /// # Arguments
1173 ///
1174 /// * `batch_size` - Maximum number of images to batch for detection
1175 ///
1176 /// # Returns
1177 ///
1178 /// The updated builder instance
1179 pub fn max_detection_batch_size(mut self, batch_size: usize) -> Self {
1180 let bs = Self::validate_min_size_usize(batch_size, "max_detection_batch_size");
1181 with_nested!(self.config.dynamic_batching, DynamicBatchConfig, config => {
1182 config.max_detection_batch_size = bs;
1183 });
1184 self
1185 }
1186
1187 /// Sets the maximum batch size for recognition.
1188 ///
1189 /// # Arguments
1190 ///
1191 /// * `batch_size` - Maximum number of text regions to batch for recognition
1192 ///
1193 /// # Returns
1194 ///
1195 /// The updated builder instance
1196 pub fn max_recognition_batch_size(mut self, batch_size: usize) -> Self {
1197 let bs = Self::validate_min_size_usize(batch_size, "max_recognition_batch_size");
1198 with_nested!(self.config.dynamic_batching, DynamicBatchConfig, config => {
1199 config.max_recognition_batch_size = bs;
1200 });
1201 self
1202 }
1203
1204 /// Sets the minimum batch size to trigger dynamic batching.
1205 ///
1206 /// # Arguments
1207 ///
1208 /// * `min_size` - Minimum number of items needed to enable batching
1209 ///
1210 /// # Returns
1211 ///
1212 /// The updated builder instance
1213 pub fn min_batch_size(mut self, min_size: usize) -> Self {
1214 let ms = Self::validate_min_size_usize(min_size, "min_batch_size");
1215 if self.config.dynamic_batching.is_none() {
1216 self.config.dynamic_batching = Some(DynamicBatchConfig::default());
1217 }
1218 if let Some(ref mut config) = self.config.dynamic_batching {
1219 config.min_batch_size = ms;
1220 }
1221 self
1222 }
1223
1224 /// Sets the shape compatibility strategy for dynamic batching.
1225 ///
1226 /// # Arguments
1227 ///
1228 /// * `strategy` - The shape compatibility strategy to use
1229 ///
1230 /// # Returns
1231 ///
1232 /// The updated builder instance
1233 pub fn shape_compatibility_strategy(mut self, strategy: ShapeCompatibilityStrategy) -> Self {
1234 if self.config.dynamic_batching.is_none() {
1235 self.config.dynamic_batching = Some(DynamicBatchConfig::default());
1236 }
1237 if let Some(ref mut config) = self.config.dynamic_batching {
1238 config.shape_compatibility = strategy;
1239 }
1240 self
1241 }
1242
1243 /// Sets aspect ratio tolerance for shape compatibility.
1244 ///
1245 /// This is a convenience method that sets the shape compatibility strategy
1246 /// to AspectRatio with the specified tolerance.
1247 ///
1248 /// # Arguments
1249 ///
1250 /// * `tolerance` - Tolerance for aspect ratio matching (e.g., 0.1 means ±10%)
1251 ///
1252 /// # Returns
1253 ///
1254 /// The updated builder instance
1255 pub fn aspect_ratio_tolerance(mut self, tolerance: f32) -> Self {
1256 let tol = Self::validate_non_negative_f32(tolerance, "aspect_ratio_tolerance");
1257 if self.config.dynamic_batching.is_none() {
1258 self.config.dynamic_batching = Some(DynamicBatchConfig::default());
1259 }
1260 if let Some(ref mut config) = self.config.dynamic_batching {
1261 config.shape_compatibility = ShapeCompatibilityStrategy::AspectRatio { tolerance: tol };
1262 }
1263 self
1264 }
1265
1266 /// Sets exact shape matching for dynamic batching.
1267 ///
1268 /// This requires images to have identical dimensions to be batched together.
1269 ///
1270 /// # Returns
1271 ///
1272 /// The updated builder instance
1273 pub fn exact_shape_matching(mut self) -> Self {
1274 if self.config.dynamic_batching.is_none() {
1275 self.config.dynamic_batching = Some(DynamicBatchConfig::default());
1276 }
1277 if let Some(ref mut config) = self.config.dynamic_batching {
1278 config.shape_compatibility = ShapeCompatibilityStrategy::Exact;
1279 }
1280 self
1281 }
1282
1283 /// Builds the OAROCR instance with the configured parameters.
1284 ///
1285 /// # Returns
1286 ///
1287 /// A Result containing the OAROCR instance or an OCRError
1288 pub fn build(self) -> crate::core::OcrResult<super::OAROCR> {
1289 super::OAROCR::new(self.config)
1290 }
1291
1292 /// Gets a reference to the configuration for testing purposes.
1293 ///
1294 /// # Returns
1295 ///
1296 /// A reference to the OAROCRConfig
1297 #[cfg(test)]
1298 pub fn get_config(&self) -> &OAROCRConfig {
1299 &self.config
1300 }
1301}
1302
// Generates additional builder methods on `OAROCRBuilder` via the `impl_complete_builder!`
// macro. This invocation is live code, not a commented-out example: it emits
// `enable_dynamic_batching` and `enable_aspect_ratio_bucketing`. Per the description strings
// passed below, each one enables the named feature with its default configuration
// (`DynamicBatchConfig` / `AspectRatioBucketingConfig`). The exact expansion lives in the
// macro definition.

impl_complete_builder! {
    builder: OAROCRBuilder,
    config_field: config,
    enable_methods: {
        enable_dynamic_batching => dynamic_batching: DynamicBatchConfig => "Enables dynamic batching with default configuration",
        enable_aspect_ratio_bucketing => aspect_ratio_bucketing: AspectRatioBucketingConfig => "Enables aspect ratio bucketing with default configuration",
    }
}
1316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::onnx::{OrtExecutionProvider, OrtSessionConfig};

    /// Creates a builder pointing at placeholder model and dictionary paths.
    ///
    /// None of the tests below call `build()`; they only inspect the stored
    /// configuration through `get_config()`.
    fn test_builder() -> OAROCRBuilder {
        OAROCRBuilder::new(
            "test_detection_model.onnx".to_string(),
            "test_recognition_model.onnx".to_string(),
            "test_char_dict.txt".to_string(),
        )
    }

    /// ORT session configuration that only uses the CPU execution provider.
    fn cpu_ort_config() -> OrtSessionConfig {
        OrtSessionConfig::new().with_execution_providers(vec![OrtExecutionProvider::CPU])
    }

    #[test]
    fn test_ort_session_configuration_propagation() {
        let ort_config = cpu_ort_config();

        let builder = test_builder()
            .text_detection_ort_session(ort_config.clone())
            .text_recognition_ort_session(ort_config);

        let config = builder.get_config();

        let detection_ort = config
            .detection
            .common
            .ort_session
            .as_ref()
            .expect("detection ORT session should be set");
        let recognition_ort = config
            .recognition
            .common
            .ort_session
            .as_ref()
            .expect("recognition ORT session should be set");

        assert_eq!(
            detection_ort.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
        assert_eq!(
            recognition_ort.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    #[test]
    fn test_session_pool_size_configuration_propagation() {
        let builder = test_builder()
            .text_detection_session_pool_size(4)
            .text_recognition_session_pool_size(8);

        let config = builder.get_config();

        assert_eq!(config.detection.common.session_pool_size, Some(4));
        assert_eq!(config.recognition.common.session_pool_size, Some(8));
    }

    #[test]
    fn test_global_ort_session_configuration() {
        let builder = test_builder().global_ort_session(cpu_ort_config());

        let config = builder.get_config();

        // The global setting must reach every component.
        let detection_ort = config
            .detection
            .common
            .ort_session
            .as_ref()
            .expect("detection ORT session should be set");
        let recognition_ort = config
            .recognition
            .common
            .ort_session
            .as_ref()
            .expect("recognition ORT session should be set");

        assert_eq!(
            detection_ort.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
        assert_eq!(
            recognition_ort.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    #[test]
    fn test_global_session_pool_size_configuration() {
        let builder = test_builder().global_session_pool_size(6);

        let config = builder.get_config();

        assert_eq!(config.detection.common.session_pool_size, Some(6));
        assert_eq!(config.recognition.common.session_pool_size, Some(6));
    }

    #[test]
    fn test_text_det_max_side_limit_configuration() {
        let builder = test_builder().text_det_max_side_limit(5000);

        assert_eq!(builder.get_config().detection.max_side_limit, Some(5000));
    }

    #[test]
    fn test_text_det_thresh_configuration() {
        let builder = test_builder().text_det_threshold(0.4);

        assert_eq!(builder.get_config().detection.thresh, Some(0.4));
    }

    #[test]
    fn test_text_det_box_thresh_configuration() {
        let builder = test_builder().text_det_box_threshold(0.7);

        assert_eq!(builder.get_config().detection.box_thresh, Some(0.7));
    }

    #[test]
    fn test_text_det_unclip_ratio_configuration() {
        let builder = test_builder().text_det_unclip_ratio(2.0);

        assert_eq!(builder.get_config().detection.unclip_ratio, Some(2.0));
    }

    #[test]
    fn test_text_det_all_thresholds_configuration() {
        let builder = test_builder()
            .text_det_threshold(0.35)
            .text_det_box_threshold(0.65)
            .text_det_unclip_ratio(1.8);

        let config = builder.get_config();

        assert_eq!(config.detection.thresh, Some(0.35));
        assert_eq!(config.detection.box_thresh, Some(0.65));
        assert_eq!(config.detection.unclip_ratio, Some(1.8));
    }

    #[test]
    fn test_textline_orientation_configuration_propagation() {
        let builder = test_builder()
            .use_textline_orientation(true)
            .textline_orientation_ort_session(cpu_ort_config())
            .textline_orientation_session_pool_size(3);

        let config = builder.get_config();

        let tlo_config = config
            .text_line_orientation
            .as_ref()
            .expect("text line orientation should be enabled");
        assert_eq!(tlo_config.common.session_pool_size, Some(3));

        let tlo_ort = tlo_config
            .common
            .ort_session
            .as_ref()
            .expect("text line orientation ORT session should be set");
        assert_eq!(
            tlo_ort.execution_providers,
            Some(vec![OrtExecutionProvider::CPU])
        );
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_with_cuda_convenience_method() {
        let builder = test_builder().with_cuda();

        let config = builder.get_config();

        // CUDA configuration must be applied to all components.
        assert!(config.detection.common.ort_session.is_some());
        assert!(config.recognition.common.ort_session.is_some());

        let det_ort = config.detection.common.ort_session.as_ref().unwrap();
        if let Some(providers) = &det_ort.execution_providers {
            assert!(providers.len() >= 2); // Should have CUDA + CPU fallback
            // First should be CUDA
            if let OrtExecutionProvider::CUDA { device_id, .. } = &providers[0] {
                assert_eq!(*device_id, Some(0)); // Default device 0
            } else {
                panic!("Expected CUDA provider as first execution provider");
            }
            // Second should be CPU fallback
            assert!(matches!(providers[1], OrtExecutionProvider::CPU));
        }
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_with_cuda_device_convenience_method() {
        let builder = test_builder().with_cuda_device(2);

        let config = builder.get_config();

        let det_ort = config.detection.common.ort_session.as_ref().unwrap();
        if let Some(providers) = &det_ort.execution_providers {
            if let OrtExecutionProvider::CUDA { device_id, .. } = &providers[0] {
                assert_eq!(*device_id, Some(2)); // Should use device 2
            } else {
                panic!("Expected CUDA provider with device 2");
            }
        }
    }

    #[test]
    fn test_with_high_performance_convenience_method() {
        let builder = test_builder().with_high_performance();

        let policy = &builder.get_config().parallel_policy;

        // Should have a reasonable thread count.
        let max_threads = policy
            .max_threads
            .expect("high performance policy should set max_threads");
        assert!((1..=128).contains(&max_threads));

        // Low thresholds favour parallel processing.
        assert_eq!(policy.image_threshold, 2);
        assert_eq!(policy.text_box_threshold, 5);
        assert_eq!(policy.batch_threshold, 3);
    }

    #[test]
    fn test_with_low_resource_convenience_method() {
        let builder = test_builder().with_low_resource();

        let policy = &builder.get_config().parallel_policy;

        // Threads are limited for resource-constrained environments.
        assert_eq!(policy.max_threads, Some(2));

        // Higher thresholds avoid parallel processing on small workloads.
        assert_eq!(policy.image_threshold, 10);
        assert_eq!(policy.text_box_threshold, 20);
        assert_eq!(policy.batch_threshold, 10);
    }

    #[test]
    fn test_validation_helper_functions() {
        let builder = test_builder()
            .doc_orientation_threshold(1.5) // Out of range, should be clamped to 1.0
            .textline_orientation_threshold(-0.5) // Out of range, should be clamped to 0.0
            .text_det_threshold(0.5) // Valid
            .text_det_box_threshold(0.8) // Valid
            .text_rec_score_threshold(2.0); // Out of range, should be clamped to 1.0

        let config = builder.get_config();

        // NOTE(review): the orientation-stage checks only run when those stage configs
        // exist; whether the threshold setters create them is decided outside this module.
        if let Some(ref orientation_config) = config.orientation_stage {
            assert_eq!(orientation_config.confidence_threshold, Some(1.0));
        }
        if let Some(ref tlo_config) = config.text_line_orientation_stage {
            assert_eq!(tlo_config.confidence_threshold, Some(0.0));
        }
        assert_eq!(config.detection.thresh, Some(0.5));
        assert_eq!(config.detection.box_thresh, Some(0.8));
        assert_eq!(config.recognition.score_thresh, Some(1.0));
    }

    #[test]
    fn test_batch_size_validation() {
        let builder = test_builder()
            .text_detection_batch_size(0) // Invalid, should be set to 1
            .text_recognition_batch_size(5) // Valid
            .global_session_pool_size(0) // Invalid, should be set to 1
            .max_detection_batch_size(10) // Valid
            .max_recognition_batch_size(0); // Invalid, should be set to 1

        let config = builder.get_config();

        assert_eq!(config.detection.common.batch_size, Some(1));
        assert_eq!(config.recognition.common.batch_size, Some(5));
        assert_eq!(config.detection.common.session_pool_size, Some(1));
        assert_eq!(config.recognition.common.session_pool_size, Some(1));

        if let Some(ref dynamic_config) = config.dynamic_batching {
            assert_eq!(dynamic_config.max_detection_batch_size, 10);
            assert_eq!(dynamic_config.max_recognition_batch_size, 1);
        }
    }

    #[test]
    fn test_dimension_validation() {
        let builder = test_builder()
            .text_det_input_shape((0, 640, 640)) // Invalid channels, should be set to 1
            .text_rec_input_shape((3, 0, 32)) // Invalid height, should be set to 1
            .textline_orientation_input_shape((0, 0)); // Invalid dimensions, should be set to 1

        let config = builder.get_config();

        assert_eq!(config.detection.input_shape, Some((1, 640, 640)));
        assert_eq!(config.recognition.model_input_shape, Some([3, 1, 32]));

        if let Some(ref tlo_config) = config.text_line_orientation {
            assert_eq!(tlo_config.input_shape, Some((1, 1)));
        }
    }

    #[test]
    fn test_positive_float_validation() {
        let builder = test_builder()
            .text_det_unclip_ratio(-1.0) // Invalid, should be set to 1.0
            .aspect_ratio_tolerance(-0.5); // Invalid, should be set to 0.0

        let config = builder.get_config();

        assert_eq!(config.detection.unclip_ratio, Some(1.0));

        // `aspect_ratio_tolerance` always creates the dynamic batching config, so the
        // tolerance check must not be skipped silently.
        let dynamic_config = config
            .dynamic_batching
            .as_ref()
            .expect("aspect_ratio_tolerance should enable dynamic batching");
        match dynamic_config.shape_compatibility {
            ShapeCompatibilityStrategy::AspectRatio { tolerance } => assert_eq!(tolerance, 0.0),
            _ => panic!("expected an AspectRatio shape compatibility strategy"),
        }
    }

    #[test]
    fn test_dynamic_batching_setters_create_default_config() {
        // Setters must work even when dynamic batching was never configured, and a later
        // strategy setter must override an earlier one.
        let builder = test_builder()
            .disable_dynamic_batching()
            .min_batch_size(0) // Invalid, should be set to 1
            .aspect_ratio_tolerance(0.25)
            .exact_shape_matching();

        let dynamic_config = builder
            .get_config()
            .dynamic_batching
            .as_ref()
            .expect("dynamic batching setters should create a default config");
        assert_eq!(dynamic_config.min_batch_size, 1);
        assert!(matches!(
            dynamic_config.shape_compatibility,
            ShapeCompatibilityStrategy::Exact
        ));

        let disabled = test_builder().min_batch_size(4).disable_dynamic_batching();
        assert!(disabled.get_config().dynamic_batching.is_none());
    }
}