Skip to main content

oar_ocr/oarocr/
ocr.rs

1//! High-level OCR builder API.
2//!
3//! This module provides `OAROCRBuilder` for constructing OCR pipelines with a fluent API.
4//! It simplifies the process of configuring text detection, recognition, and optional
5//! preprocessing components.
6
7use super::builder_utils::build_optional_adapter;
8use oar_ocr_core::core::config::OrtSessionConfig;
9use oar_ocr_core::core::constants::DEFAULT_REC_IMAGE_SHAPE;
10use oar_ocr_core::core::errors::OCRError;
11use oar_ocr_core::core::traits::OrtConfigurable;
12use oar_ocr_core::core::traits::adapter::{AdapterBuilder, ModelAdapter};
13use oar_ocr_core::core::traits::task::ImageTaskInput;
14use oar_ocr_core::domain::adapters::{
15    DocumentOrientationAdapter, DocumentOrientationAdapterBuilder, TextDetectionAdapter,
16    TextDetectionAdapterBuilder, TextLineOrientationAdapter, TextLineOrientationAdapterBuilder,
17    TextRecognitionAdapter, TextRecognitionAdapterBuilder, UVDocRectifierAdapter,
18    UVDocRectifierAdapterBuilder,
19};
20use oar_ocr_core::domain::tasks::{TextDetectionConfig, TextRecognitionConfig};
21use oar_ocr_core::processors::BoundingBox;
22use std::path::PathBuf;
23use std::sync::Arc;
24
/// Internal structure holding the OCR pipeline adapters.
#[derive(Debug)]
struct OCRPipeline {
    /// Optional document unwarping (UVDoc); runs during preprocessing, before detection.
    rectification_adapter: Option<UVDocRectifierAdapter>,
    /// Optional whole-page orientation classifier; runs during preprocessing, before detection.
    document_orientation_adapter: Option<DocumentOrientationAdapter>,
    /// Required text detector.
    text_detection_adapter: TextDetectionAdapter,
    /// Optional per-line orientation classifier (0°/180°), applied to cropped line images.
    text_line_orientation_adapter: Option<TextLineOrientationAdapter>,
    /// Required text recognizer.
    text_recognition_adapter: TextRecognitionAdapter,
}
34
/// Builder for constructing OCR pipelines.
///
/// This builder provides a high-level API for configuring text detection and recognition
/// pipelines with optional preprocessing components like orientation classification and
/// image rectification.
///
/// # Example
///
/// ```no_run
/// use oar_ocr::oarocr::OAROCRBuilder;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let ocr = OAROCRBuilder::new(
///     "path/to/text_detection.onnx",
///     "path/to/text_recognition.onnx",
///     "path/to/character_dict.txt"
/// )
/// .with_document_image_orientation_classification("path/to/orientation.onnx")
/// .with_text_line_orientation_classification("path/to/line_orientation.onnx")
/// .image_batch_size(4)
/// .region_batch_size(32)
/// .build()?;
/// # let _ = ocr;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct OAROCRBuilder {
    // Required fields
    /// Path to the text detection ONNX model.
    text_detection_model: PathBuf,
    /// Path to the text recognition ONNX model.
    text_recognition_model: PathBuf,
    /// Path to the character dictionary file (one character per line).
    character_dict_path: PathBuf,

    // Optional components
    /// Model path for whole-document orientation classification, if enabled.
    document_orientation_model: Option<PathBuf>,
    /// Model path for per-line orientation classification, if enabled.
    text_line_orientation_model: Option<PathBuf>,
    /// Model path for document rectification (unwarping), if enabled.
    document_rectification_model: Option<PathBuf>,

    // Configuration
    /// ONNX Runtime session settings shared by all models in the pipeline.
    ort_session_config: Option<OrtSessionConfig>,
    /// Explicit detection config; when `None`, `build()` applies text-type defaults.
    text_detection_config: Option<TextDetectionConfig>,
    /// Explicit recognition config; when `None`, the adapter builder's defaults apply.
    text_recognition_config: Option<TextRecognitionConfig>,
    /// Batch size for text detection over input images.
    image_batch_size: Option<usize>,
    /// Batch size for recognizing detected text regions.
    region_batch_size: Option<usize>,

    // Text type and word box options
    /// Text type ("seal", "table", or general) driving sorting/cropping strategy.
    text_type: Option<String>,
    /// Whether to compute word-level bounding boxes during recognition.
    return_word_box: bool,
}
84
85impl OAROCRBuilder {
86    /// Creates a new OCR builder with required components.
87    ///
88    /// # Arguments
89    ///
90    /// * `text_detection_model` - Path to the text detection ONNX model
91    /// * `text_recognition_model` - Path to the text recognition ONNX model
92    /// * `character_dict_path` - Path to the character dictionary file
93    pub fn new(
94        text_detection_model: impl Into<PathBuf>,
95        text_recognition_model: impl Into<PathBuf>,
96        character_dict_path: impl Into<PathBuf>,
97    ) -> Self {
98        Self {
99            text_detection_model: text_detection_model.into(),
100            text_recognition_model: text_recognition_model.into(),
101            character_dict_path: character_dict_path.into(),
102            document_orientation_model: None,
103            text_line_orientation_model: None,
104            document_rectification_model: None,
105            ort_session_config: None,
106            text_detection_config: None,
107            text_recognition_config: None,
108            image_batch_size: None,
109            region_batch_size: None,
110            text_type: None,
111            return_word_box: false,
112        }
113    }
114
115    /// Sets the ONNX Runtime session configuration.
116    ///
117    /// This configuration will be applied to all models in the pipeline.
118    pub fn ort_session(mut self, config: OrtSessionConfig) -> Self {
119        self.ort_session_config = Some(config);
120        self
121    }
122
123    /// Sets the text detection model configuration.
124    ///
125    /// The configuration should be a JSON value containing model-specific settings.
126    pub fn text_detection_config(mut self, config: TextDetectionConfig) -> Self {
127        self.text_detection_config = Some(config);
128        self
129    }
130
131    /// Sets the text recognition model configuration.
132    ///
133    /// The configuration should be a JSON value containing model-specific settings.
134    pub fn text_recognition_config(mut self, config: TextRecognitionConfig) -> Self {
135        self.text_recognition_config = Some(config);
136        self
137    }
138
139    /// Sets the batch size for processing input images during text detection.
140    ///
141    /// This controls how many images are sent to the text detection adapter per call.
142    /// If a detector cannot batch the provided images (e.g., mismatched sizes), the
143    /// pipeline falls back to per-image detection.
144    pub fn image_batch_size(mut self, size: usize) -> Self {
145        self.image_batch_size = Some(size);
146        self
147    }
148
149    /// Sets the batch size for processing detected text regions.
150    ///
151    /// Controls memory usage during text recognition. Smaller values use less memory.
152    /// Recommended: 32 for medium VRAM, 16 for low VRAM/CPU.
153    pub fn region_batch_size(mut self, size: usize) -> Self {
154        self.region_batch_size = Some(size);
155        self
156    }
157
158    /// Adds document image orientation classification to the pipeline.
159    ///
160    /// This component detects and corrects document orientation before text detection.
161    pub fn with_document_image_orientation_classification(
162        mut self,
163        model_path: impl Into<PathBuf>,
164    ) -> Self {
165        self.document_orientation_model = Some(model_path.into());
166        self
167    }
168
169    /// Adds text line orientation classification to the pipeline.
170    ///
171    /// This component detects and corrects text line orientation after text detection.
172    pub fn with_text_line_orientation_classification(
173        mut self,
174        model_path: impl Into<PathBuf>,
175    ) -> Self {
176        self.text_line_orientation_model = Some(model_path.into());
177        self
178    }
179
180    /// Adds document image rectification to the pipeline.
181    ///
182    /// This component corrects document distortion before text detection.
183    pub fn with_document_image_rectification(mut self, model_path: impl Into<PathBuf>) -> Self {
184        self.document_rectification_model = Some(model_path.into());
185        self
186    }
187
188    /// Sets the text type for sorting and cropping strategy.
189    ///
190    /// This matches the text_type parameter:
191    /// - "seal": Uses polygon-based sorting/cropping for seal text (circular/curved)
192    /// - "table": Uses table-friendly detection defaults (box_threshold=0.4)
193    /// - Other values or None: Uses quad-based sorting (default)
194    ///
195    /// # Arguments
196    ///
197    /// * `text_type` - Text type identifier ("seal", etc.)
198    pub fn text_type(mut self, text_type: impl Into<String>) -> Self {
199        self.text_type = Some(text_type.into());
200        self
201    }
202
203    /// Enables word-level bounding box detection.
204    ///
205    /// When enabled, the pipeline will attempt to detect individual words
206    /// within each text line and populate the `word_boxes` field in `TextRegion`.
207    ///
208    /// Note: This feature requires word-level detection support in the recognition model.
209    ///
210    /// # Arguments
211    ///
212    /// * `enable` - Whether to enable word box detection
213    pub fn return_word_box(mut self, enable: bool) -> Self {
214        self.return_word_box = enable;
215        self
216    }
217
218    /// Builds the OCR runtime.
219    ///
220    /// This instantiates all adapters and returns an `OAROCR` instance ready for prediction.
221    pub fn build(self) -> Result<OAROCR, OCRError> {
222        // Load character dictionary for text recognition
223        let char_dict = std::fs::read_to_string(&self.character_dict_path).map_err(|e| {
224            OCRError::InvalidInput {
225                message: format!(
226                    "Failed to read character dictionary from '{}': {}",
227                    self.character_dict_path.display(),
228                    e
229                ),
230            }
231        })?;
232
233        // Build document rectification adapter if enabled
234        let rectification_adapter = build_optional_adapter(
235            self.document_rectification_model.as_ref(),
236            self.ort_session_config.as_ref(),
237            UVDocRectifierAdapterBuilder::new,
238        )?;
239
240        // Build document orientation adapter if enabled
241        let document_orientation_adapter = build_optional_adapter(
242            self.document_orientation_model.as_ref(),
243            self.ort_session_config.as_ref(),
244            DocumentOrientationAdapterBuilder::new,
245        )?;
246
247        // Build text detection adapter (required)
248        let mut detection_builder = TextDetectionAdapterBuilder::new();
249
250        if let Some(ref ort_config) = self.ort_session_config {
251            detection_builder = detection_builder.with_ort_config(ort_config.clone());
252        }
253
254        // Align text detection defaults with OCR pipeline.
255        // Defaults depend on text_type:
256        // - general: limit_side_len=960, limit_type="max", thresh=0.3, box_thresh=0.6, unclip_ratio=2.0
257        // - table: limit_side_len=960, limit_type="max", thresh=0.3, box_thresh=0.4, unclip_ratio=2.0
258        // - seal: limit_side_len=736, limit_type="min", thresh=0.2, box_thresh=0.6, unclip_ratio=0.5
259        let mut effective_det_cfg = self.text_detection_config.clone().unwrap_or_default();
260        let has_explicit_det_cfg = self.text_detection_config.is_some();
261        if !has_explicit_det_cfg {
262            match self.text_type.as_deref().unwrap_or("general") {
263                "table" => {
264                    effective_det_cfg.score_threshold = 0.3;
265                    effective_det_cfg.box_threshold = 0.4;
266                    effective_det_cfg.unclip_ratio = 2.0;
267                    if effective_det_cfg.limit_side_len.is_none() {
268                        effective_det_cfg.limit_side_len = Some(960);
269                    }
270                    if effective_det_cfg.limit_type.is_none() {
271                        effective_det_cfg.limit_type = Some(crate::processors::LimitType::Max);
272                    }
273                    if effective_det_cfg.max_side_len.is_none() {
274                        effective_det_cfg.max_side_len = Some(4000);
275                    }
276                }
277                "seal" => {
278                    effective_det_cfg.score_threshold = 0.2;
279                    effective_det_cfg.box_threshold = 0.6;
280                    effective_det_cfg.unclip_ratio = 0.5;
281                    if effective_det_cfg.limit_side_len.is_none() {
282                        effective_det_cfg.limit_side_len = Some(736);
283                    }
284                    if effective_det_cfg.limit_type.is_none() {
285                        effective_det_cfg.limit_type = Some(crate::processors::LimitType::Min);
286                    }
287                    if effective_det_cfg.max_side_len.is_none() {
288                        effective_det_cfg.max_side_len = Some(4000);
289                    }
290                }
291                _ => {
292                    effective_det_cfg.score_threshold = 0.3;
293                    effective_det_cfg.box_threshold = 0.6;
294                    effective_det_cfg.unclip_ratio = 2.0;
295                    if effective_det_cfg.limit_side_len.is_none() {
296                        effective_det_cfg.limit_side_len = Some(960);
297                    }
298                    if effective_det_cfg.limit_type.is_none() {
299                        effective_det_cfg.limit_type = Some(crate::processors::LimitType::Max);
300                    }
301                    if effective_det_cfg.max_side_len.is_none() {
302                        effective_det_cfg.max_side_len = Some(4000);
303                    }
304                }
305            }
306        }
307
308        detection_builder = detection_builder.with_config(effective_det_cfg);
309
310        // Pass text_type to detection adapter for proper preprocessing configuration
311        if let Some(ref text_type) = self.text_type {
312            detection_builder = detection_builder.text_type(text_type.clone());
313        }
314
315        let text_detection_adapter = detection_builder.build(&self.text_detection_model)?;
316
317        // Build text line orientation adapter if enabled
318        let text_line_orientation_adapter = build_optional_adapter(
319            self.text_line_orientation_model.as_ref(),
320            self.ort_session_config.as_ref(),
321            TextLineOrientationAdapterBuilder::new,
322        )?;
323
324        // Build text recognition adapter (required)
325        // Parse char_dict into Vec<String> - one character per line
326        let char_dict_vec: Vec<String> = char_dict.lines().map(|s| s.to_string()).collect();
327
328        let mut recognition_builder = TextRecognitionAdapterBuilder::new()
329            .character_dict(char_dict_vec)
330            .return_word_box(self.return_word_box);
331
332        if let Some(ref ort_config) = self.ort_session_config {
333            recognition_builder = recognition_builder.with_ort_config(ort_config.clone());
334        }
335
336        if let Some(ref rec_config) = self.text_recognition_config {
337            recognition_builder = recognition_builder.with_config(rec_config.clone());
338        }
339
340        let text_recognition_adapter = recognition_builder.build(&self.text_recognition_model)?;
341
342        let pipeline = OCRPipeline {
343            rectification_adapter,
344            document_orientation_adapter,
345            text_detection_adapter,
346            text_line_orientation_adapter,
347            text_recognition_adapter,
348        };
349
350        Ok(OAROCR {
351            pipeline,
352            text_type: self.text_type,
353            return_word_box: self.return_word_box,
354            image_batch_size: self.image_batch_size,
355            region_batch_size: self.region_batch_size,
356        })
357    }
358}
359
/// OCR runtime for executing text detection and recognition.
///
/// This struct represents a configured OCR pipeline that can process images
/// to extract text.
#[derive(Debug)]
pub struct OAROCR {
    /// Configured adapters (detection, recognition, and optional preprocessing).
    pipeline: OCRPipeline,
    /// Text type ("seal", "table", or general) controlling box-sorting strategy.
    text_type: Option<String>,
    /// Whether word-level bounding boxes are computed during recognition.
    return_word_box: bool,
    /// Text detection batch size for `predict(images)`.
    ///
    /// This controls how many preprocessed images are sent to the text detection adapter in a
    /// single call. If `None`, the adapter's `recommended_batch_size()` is used.
    image_batch_size: Option<usize>,
    /// Batch size for text region recognition
    region_batch_size: Option<usize>,
}
377
/// A cropped text-line image awaiting recognition, tied back to the detection
/// box it was extracted from.
struct CroppedTextRegion {
    /// Index into the original detection-box list (preserves reading order).
    detection_index: usize,
    /// The detection box this crop came from.
    bbox: BoundingBox,
    /// The cropped (and possibly 180°-rotated) line image.
    image: image::RgbImage,
    /// Width/height aspect ratio of the crop; used to group recognition batches.
    wh_ratio: f32,
    /// Line orientation angle in degrees (0.0 or 180.0) when classification ran.
    line_orientation_angle: Option<f32>,
}
385
// Manual Debug impl: summarizes the image as "RgbImage(WxH)" instead of
// dumping the raw pixel buffer that a derived impl would print.
impl std::fmt::Debug for CroppedTextRegion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CroppedTextRegion")
            .field("detection_index", &self.detection_index)
            .field("bbox", &self.bbox)
            .field(
                "image",
                &format_args!("RgbImage({}x{})", self.image.width(), self.image.height()),
            )
            .field("wh_ratio", &self.wh_ratio)
            .field("line_orientation_angle", &self.line_orientation_angle)
            .finish()
    }
}
400
401impl OAROCR {
402    /// Predicts text from images using the configured OCR pipeline.
403    ///
404    /// This method orchestrates the execution of all configured tasks in the pipeline,
405    /// including optional components like document orientation, rectification, and
406    /// text line orientation classification.
407    ///
408    /// # Arguments
409    ///
410    /// * `images` - Collection of RGB images to process
411    ///
412    /// # Returns
413    ///
414    /// A vector of `OAROCRResult` containing the OCR results for each image,
415    /// or an error if processing fails.
416    ///
417    /// # Example
418    ///
419    /// ```no_run
420    /// use oar_ocr::oarocr::ocr::OAROCRBuilder;
421    /// use oar_ocr::utils::load_image;
422    /// use std::path::Path;
423    ///
424    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
425    /// let ocr = OAROCRBuilder::new(
426    ///     "models/det.onnx",
427    ///     "models/rec.onnx",
428    ///     "models/dict.txt",
429    /// ).build()?;
430    ///
431    /// let image = load_image(Path::new("document.jpg"))?;
432    /// let results = ocr.predict(vec![image])?;
433    ///
434    /// for result in results {
435    ///     for region in result.text_regions {
436    ///         if let Some(text) = region.text {
437    ///             println!("Text: {}", text);
438    ///         }
439    ///     }
440    /// }
441    /// # Ok(())
442    /// # }
443    /// ```
444    pub fn predict(
445        &self,
446        images: Vec<image::RgbImage>,
447    ) -> Result<Vec<crate::oarocr::OAROCRResult>, OCRError> {
448        use crate::oarocr::preprocess::DocumentPreprocessor;
449        use std::sync::Arc;
450
451        if images.is_empty() {
452            return Err(OCRError::validation_error(
453                "OCR Pipeline",
454                "images",
455                "non-empty slice",
456                "empty slice",
457            ));
458        }
459
460        let preprocessor = DocumentPreprocessor::new(
461            self.pipeline.document_orientation_adapter.as_ref(),
462            self.pipeline.rectification_adapter.as_ref(),
463        );
464
465        let mut prepared: Vec<(
466            Arc<image::RgbImage>,
467            crate::oarocr::preprocess::PreprocessResult,
468        )> = Vec::with_capacity(images.len());
469
470        for image in images.into_iter() {
471            let input_img_arc = Arc::new(image);
472            let preprocess = preprocessor.preprocess(Arc::clone(&input_img_arc))?;
473            prepared.push((input_img_arc, preprocess));
474        }
475
476        let det_batch_size = self
477            .image_batch_size
478            .unwrap_or_else(|| {
479                self.pipeline
480                    .text_detection_adapter
481                    .recommended_batch_size()
482            })
483            .max(1);
484
485        let mut all_detection_boxes: Vec<Vec<BoundingBox>> = vec![Vec::new(); prepared.len()];
486
487        let mut start = 0usize;
488        while start < prepared.len() {
489            let end = (start + det_batch_size).min(prepared.len());
490
491            // Adapter boundary: must clone to transfer ownership
492            let batch_images: Vec<image::RgbImage> = prepared[start..end]
493                .iter()
494                .map(|(_, preprocess)| (*preprocess.image).clone())
495                .collect();
496
497            match self.detect_sorted_text_boxes_batch(batch_images) {
498                Ok(batch_boxes) => {
499                    for (offset, boxes) in batch_boxes.into_iter().enumerate() {
500                        all_detection_boxes[start + offset] = boxes;
501                    }
502                }
503                Err(err) => {
504                    tracing::warn!(
505                        target: "ocr",
506                        error = %err,
507                        batch_start = start,
508                        batch_end = end,
509                        "Batched text detection failed; falling back to per-image detection"
510                    );
511                    for i in start..end {
512                        all_detection_boxes[i] =
513                            self.detect_sorted_text_boxes(&prepared[i].1.image)?;
514                    }
515                }
516            }
517
518            start = end;
519        }
520
521        let mut results = Vec::with_capacity(prepared.len());
522        for (img_idx, (input_img_arc, preprocess)) in prepared.into_iter().enumerate() {
523            let detection_boxes = all_detection_boxes[img_idx].clone();
524            results.push(self.predict_single(
525                img_idx,
526                input_img_arc,
527                preprocess,
528                detection_boxes,
529            )?);
530        }
531
532        Ok(results)
533    }
534
535    fn predict_single(
536        &self,
537        img_idx: usize,
538        input_img: std::sync::Arc<image::RgbImage>,
539        preprocess: crate::oarocr::preprocess::PreprocessResult,
540        detection_boxes: Vec<BoundingBox>,
541    ) -> Result<crate::oarocr::OAROCRResult, OCRError> {
542        use std::sync::Arc;
543
544        let current_image = preprocess.image;
545
546        let mut cropped_regions = self.crop_text_regions(&current_image, &detection_boxes)?;
547        self.classify_line_orientations(&mut cropped_regions)?;
548
549        let recognized = self.recognize_text_regions(detection_boxes.len(), cropped_regions)?;
550
551        // Preserve reading order by emitting in detection-index order.
552        let mut text_regions: Vec<crate::oarocr::TextRegion> =
553            recognized.into_iter().flatten().collect();
554
555        if let Some(rot) = preprocess.rotation {
556            Self::rotate_text_regions_back(&mut text_regions, rot);
557        }
558
559        Ok(crate::oarocr::OAROCRResult {
560            input_path: Arc::from(format!("image_{}", img_idx)),
561            index: img_idx,
562            input_img,
563            text_regions,
564            orientation_angle: preprocess.orientation_angle,
565            rectified_img: preprocess.rectified_img,
566        })
567    }
568
569    fn detect_sorted_text_boxes_batch(
570        &self,
571        images: Vec<image::RgbImage>,
572    ) -> Result<Vec<Vec<BoundingBox>>, OCRError> {
573        if images.is_empty() {
574            return Ok(Vec::new());
575        }
576
577        let input = ImageTaskInput::new(images);
578        let det = self.pipeline.text_detection_adapter.execute(input, None)?;
579
580        let mut results: Vec<Vec<BoundingBox>> = Vec::with_capacity(det.detections.len());
581        for detections in det.detections.into_iter() {
582            let boxes = detections.into_iter().map(|d| d.bbox).collect::<Vec<_>>();
583            results.push(self.sort_detection_boxes(&boxes));
584        }
585
586        Ok(results)
587    }
588
589    fn detect_sorted_text_boxes(
590        &self,
591        image: &image::RgbImage,
592    ) -> Result<Vec<BoundingBox>, OCRError> {
593        let input = ImageTaskInput::new(vec![image.clone()]);
594        let det = self.pipeline.text_detection_adapter.execute(input, None)?;
595
596        let boxes = det
597            .detections
598            .into_iter()
599            .next()
600            .unwrap_or_default()
601            .into_iter()
602            .map(|d| d.bbox)
603            .collect::<Vec<_>>();
604
605        Ok(self.sort_detection_boxes(&boxes))
606    }
607
608    fn sort_detection_boxes(&self, boxes: &[BoundingBox]) -> Vec<BoundingBox> {
609        if boxes.is_empty() {
610            return Vec::new();
611        }
612
613        let is_seal_text = self
614            .text_type
615            .as_ref()
616            .map(|t| t.to_lowercase() == "seal")
617            .unwrap_or(false);
618
619        if is_seal_text {
620            crate::processors::sort_poly_boxes(boxes)
621        } else {
622            crate::processors::sort_quad_boxes(boxes)
623        }
624    }
625
626    fn crop_text_regions(
627        &self,
628        image: &Arc<image::RgbImage>,
629        detection_boxes: &[BoundingBox],
630    ) -> Result<Vec<CroppedTextRegion>, OCRError> {
631        use crate::oarocr::EdgeProcessor;
632        use crate::oarocr::TextCroppingProcessor;
633
634        if detection_boxes.is_empty() {
635            return Ok(Vec::new());
636        }
637
638        let processor = TextCroppingProcessor::new(true); // handle_rotation = true
639        // Zero-copy: share Arc instead of cloning the image
640        let cropped = processor.process((Arc::clone(image), detection_boxes.to_vec()))?;
641
642        let mut regions = Vec::new();
643        for (idx, crop_result) in cropped.into_iter().enumerate() {
644            let Some(img) = crop_result else {
645                continue;
646            };
647            // Small cropped images: clone is acceptable
648            let img = (*img).clone();
649            let wh_ratio = img.width() as f32 / img.height().max(1) as f32;
650            regions.push(CroppedTextRegion {
651                detection_index: idx,
652                bbox: detection_boxes
653                    .get(idx)
654                    .cloned()
655                    .unwrap_or_else(|| BoundingBox::from_coords(0.0, 0.0, 0.0, 0.0)),
656                image: img,
657                wh_ratio,
658                line_orientation_angle: None,
659            });
660        }
661
662        Ok(regions)
663    }
664
    /// Classifies the orientation of each cropped text line and rotates
    /// upside-down crops 180° in place.
    ///
    /// No-op when the pipeline has no text-line-orientation adapter or when
    /// `regions` is empty. The predicted angle is stored on each region so it
    /// can be reported in the final `TextRegion`.
    fn classify_line_orientations(
        &self,
        regions: &mut [CroppedTextRegion],
    ) -> Result<(), OCRError> {
        // Line-orientation classification is optional; skip when not configured.
        let Some(ref line_orientation_adapter) = self.pipeline.text_line_orientation_adapter else {
            return Ok(());
        };

        if regions.is_empty() {
            return Ok(());
        }

        // Adapter boundary: the task input takes owned images, so clone the crops.
        let input_images = regions.iter().map(|r| r.image.clone()).collect();
        let input = ImageTaskInput::new(input_images);
        let orient = line_orientation_adapter.execute(input, None)?;

        // take(regions.len()) guards against the adapter returning more
        // classification entries than inputs.
        for (idx, classifications) in orient
            .classifications
            .iter()
            .enumerate()
            .take(regions.len())
        {
            // Regions with no classification result keep angle = None.
            let Some(top_class) = classifications.first() else {
                continue;
            };

            // Convert class_id to angle (0=0°, 1=180°)
            let angle = (top_class.class_id as f32) * 180.0;
            regions[idx].line_orientation_angle = Some(angle);

            // Flip upside-down lines so recognition sees upright text.
            if top_class.class_id == 1 {
                regions[idx].image = image::imageops::rotate180(&regions[idx].image);
            }
        }

        Ok(())
    }
702
    /// Runs text recognition over the cropped regions and maps the results
    /// back into detection order.
    ///
    /// Returns one slot per detection box (`detection_count` total); slots
    /// whose crop was skipped or produced no recognition output stay `None`.
    fn recognize_text_regions(
        &self,
        detection_count: usize,
        mut regions: Vec<CroppedTextRegion>,
    ) -> Result<Vec<Option<crate::oarocr::TextRegion>>, OCRError> {
        let mut results: Vec<Option<crate::oarocr::TextRegion>> = vec![None; detection_count];
        if regions.is_empty() {
            return Ok(results);
        }

        // Sort by aspect ratio so each batch contains similarly shaped crops;
        // reading order is restored later via `detection_index` write-back.
        regions.sort_by(|a, b| {
            a.wh_ratio
                .partial_cmp(&b.wh_ratio)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Base width/height ratio derived from the recognizer's input shape.
        let base_rec_ratio = DEFAULT_REC_IMAGE_SHAPE[2] as f32 / DEFAULT_REC_IMAGE_SHAPE[1] as f32;
        // No configured region batch size => recognize everything in one batch.
        let batch_size = self.region_batch_size.unwrap_or(regions.len()).max(1);

        for chunk in regions.chunks(batch_size) {
            // Widest aspect ratio in this chunk, never below the base ratio.
            let chunk_max_wh_ratio = chunk
                .iter()
                .map(|r| r.wh_ratio)
                .fold(base_rec_ratio, |acc, r| acc.max(r));

            // Adapter boundary: the task input takes owned images.
            let rec_input = ImageTaskInput::new(chunk.iter().map(|r| r.image.clone()).collect());

            let rec = self
                .pipeline
                .text_recognition_adapter
                .execute(rec_input, None)?;

            // Guard against the adapter returning fewer texts than inputs.
            let n = rec.texts.len().min(chunk.len());
            for (i, region) in chunk.iter().take(n).enumerate() {
                let text = rec.texts.get(i).map(String::as_str).unwrap_or("");
                let score = *rec.scores.get(i).unwrap_or(&0.0);

                let char_positions: &[f32] = rec
                    .char_positions
                    .get(i)
                    .map(|v| v.as_slice())
                    .unwrap_or(&[]);
                let col_indices: &[usize] = rec
                    .char_col_indices
                    .get(i)
                    .map(|v| v.as_slice())
                    .unwrap_or(&[]);
                let seq_len = *rec.sequence_lengths.get(i).unwrap_or(&0);

                let bbox = region.bbox.clone();
                // Word boxes: prefer CTC column indices; fall back to raw
                // character positions; otherwise emit none.
                let word_boxes = if self.return_word_box && !col_indices.is_empty() && seq_len > 0 {
                    Some(Self::ctc_word_boxes(
                        &bbox,
                        text,
                        col_indices,
                        seq_len,
                        region.wh_ratio,
                        chunk_max_wh_ratio,
                    ))
                } else if self.return_word_box && !char_positions.is_empty() {
                    Some(Self::char_positions_to_word_boxes(
                        &bbox,
                        char_positions,
                        text.chars().count(),
                    ))
                } else {
                    None
                };

                // Write back into the slot for this region's detection box.
                if region.detection_index < results.len() {
                    results[region.detection_index] = Some(crate::oarocr::TextRegion {
                        bounding_box: bbox.clone(),
                        dt_poly: Some(bbox.clone()),
                        rec_poly: Some(bbox),
                        text: Some(std::sync::Arc::from(text)),
                        confidence: Some(score),
                        orientation_angle: region.line_orientation_angle,
                        word_boxes,
                        label: None,
                    });
                }
            }
        }

        Ok(results)
    }
789
790    fn rotate_text_regions_back(
791        regions: &mut [crate::oarocr::TextRegion],
792        rot: crate::oarocr::preprocess::OrientationCorrection,
793    ) {
794        for region in regions {
795            region.dt_poly = region.dt_poly.take().map(|poly| {
796                poly.rotate_back_to_original(rot.angle, rot.rotated_width, rot.rotated_height)
797            });
798            region.rec_poly = region.rec_poly.take().map(|poly| {
799                poly.rotate_back_to_original(rot.angle, rot.rotated_width, rot.rotated_height)
800            });
801            region.bounding_box = region.bounding_box.rotate_back_to_original(
802                rot.angle,
803                rot.rotated_width,
804                rot.rotated_height,
805            );
806
807            if let Some(ref word_boxes) = region.word_boxes {
808                let transformed_word_boxes: Vec<_> = word_boxes
809                    .iter()
810                    .map(|wb| {
811                        wb.rotate_back_to_original(rot.angle, rot.rotated_width, rot.rotated_height)
812                    })
813                    .collect();
814                region.word_boxes = Some(transformed_word_boxes);
815            }
816        }
817    }
818
819    /// Converts CTC column indices to word-level bounding boxes using standard approach.
820    ///
821    /// This method calculates character-specific widths based on the column indices from CTC decoding,
822    /// which provides more accurate word boxes than uniform distribution.
823    ///
824    /// It aligns with standard logic by distinguishing between CJK and other characters:
825    /// - CJK characters use a center-based approach with average character width to avoid being too narrow.
826    /// - Other characters use the standard column-based width.
827    ///
828    /// # Arguments
829    ///
830    /// * `line_bbox` - The bounding box of the entire text line
831    /// * `col_indices` - Column indices (timesteps) for each character from CTC output
832    /// * `seq_len` - Total number of columns (sequence length) in the CTC output
833    /// * `text` - The recognized text string
834    ///
835    /// # Returns
836    ///
837    /// A vector of bounding boxes, one for each character
838    fn ctc_word_boxes(
839        line_bbox: &BoundingBox,
840        text: &str,
841        col_indices: &[usize],
842        seq_len: usize,
843        wh_ratio: f32,
844        max_wh_ratio: f32,
845    ) -> Vec<BoundingBox> {
846        if col_indices.is_empty() || seq_len == 0 || text.is_empty() {
847            return Vec::new();
848        }
849
850        // Scale effective column count using standard logic (handles padding to max width)
851        let effective_col_num = (seq_len as f32) * (wh_ratio / max_wh_ratio);
852        if effective_col_num <= f32::EPSILON {
853            return Vec::new();
854        }
855
856        // Get the line bounding box coordinates
857        let x_min = line_bbox.x_min();
858        let y_min = line_bbox.y_min();
859        let x_max = line_bbox.x_max();
860        let y_max = line_bbox.y_max();
861        let width = x_max - x_min;
862
863        // Calculate cell width (width of each column in the CTC output)
864        let cell_width = width / effective_col_num.max(f32::EPSILON);
865
866        let mut word_boxes = Vec::new();
867        let chars: Vec<char> = text.chars().collect();
868        let avg_char_width = width / chars.len().max(1) as f32;
869
870        // Pre-calculate centers for all characters
871        let centers: Vec<f32> = col_indices
872            .iter()
873            .map(|&idx| x_min + (idx as f32 + 0.5) * cell_width)
874            .collect();
875
876        for (i, _) in col_indices.iter().enumerate() {
877            let ch = chars.get(i).copied().unwrap_or('?');
878            let center_x = centers[i];
879
880            if Self::is_cjk(ch) {
881                let half_width = avg_char_width / 2.0;
882                let char_x_min = (center_x - half_width).max(x_min);
883                let char_x_max = (center_x + half_width).min(x_max);
884                let char_box = BoundingBox::from_coords(char_x_min, y_min, char_x_max, y_max);
885                word_boxes.push(char_box);
886            } else {
887                // For non-CJK characters, use the midpoint between adjacent character centers
888                // to determine boundaries. This provides contiguous boxes that adapt to character density.
889                let char_x_min = if i == 0 {
890                    x_min
891                } else {
892                    (centers[i - 1] + center_x) / 2.0
893                }
894                .max(x_min);
895
896                let char_x_max = if i == col_indices.len() - 1 {
897                    x_max
898                } else {
899                    (center_x + centers[i + 1]) / 2.0
900                }
901                .min(x_max);
902
903                let char_box = BoundingBox::from_coords(char_x_min, y_min, char_x_max, y_max);
904                word_boxes.push(char_box);
905            }
906        }
907
908        word_boxes
909    }
910
911    /// Converts normalized character positions to word-level bounding boxes.
912    ///
913    /// This is a fallback method that uses uniform character width distribution.
914    /// Use col_indices_to_word_boxes when CTC column indices are available for better accuracy.
915    ///
916    /// # Arguments
917    ///
918    /// * `line_bbox` - The bounding box of the entire text line
919    /// * `char_positions` - Normalized x-positions (0.0-1.0) for each character
920    /// * `char_count` - Number of characters in the text
921    ///
922    /// # Returns
923    ///
924    /// A vector of bounding boxes, one for each character/word
925    fn char_positions_to_word_boxes(
926        line_bbox: &BoundingBox,
927        char_positions: &[f32],
928        char_count: usize,
929    ) -> Vec<BoundingBox> {
930        if char_positions.is_empty() || char_count == 0 {
931            return Vec::new();
932        }
933
934        // Get the line bounding box coordinates
935        let x_min = line_bbox.x_min();
936        let y_min = line_bbox.y_min();
937        let x_max = line_bbox.x_max();
938        let y_max = line_bbox.y_max();
939        let width = x_max - x_min;
940
941        // Calculate approximate character width
942        let char_width = width / char_count as f32;
943
944        // Create a bounding box for each character based on its position
945        let mut word_boxes = Vec::new();
946        for &pos in char_positions.iter() {
947            // Calculate x position (pos is normalized 0.0-1.0)
948            let char_x_center = x_min + (pos * width);
949
950            // Estimate character box boundaries
951            // Use half character width on each side of the position
952            let char_x_min = (char_x_center - char_width / 2.0).max(x_min);
953            let char_x_max = (char_x_center + char_width / 2.0).min(x_max);
954
955            // Use the full height of the text line for each character
956            let char_box = BoundingBox::from_coords(char_x_min, y_min, char_x_max, y_max);
957            word_boxes.push(char_box);
958        }
959
960        word_boxes
961    }
962
963    /// Detect whether a character is CJK.
964    fn is_cjk(c: char) -> bool {
965        let u = c as u32;
966        (0x4E00..=0x9FFF).contains(&u)
967            || (0x3400..=0x4DBF).contains(&u)
968            || (0x20000..=0x2A6DF).contains(&u)
969            || (0x2A700..=0x2B73F).contains(&u)
970            || (0x2B740..=0x2B81F).contains(&u)
971    }
972}
973
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_oarocr_builder_new() {
        // The three mandatory model paths must be stored verbatim...
        let builder = OAROCRBuilder::new("models/det.onnx", "models/rec.onnx", "models/dict.txt");

        assert_eq!(
            builder.text_detection_model,
            PathBuf::from("models/det.onnx")
        );
        assert_eq!(
            builder.text_recognition_model,
            PathBuf::from("models/rec.onnx")
        );
        assert_eq!(
            builder.character_dict_path,
            PathBuf::from("models/dict.txt")
        );
        // ...and every optional preprocessing stage starts disabled.
        assert!(builder.document_orientation_model.is_none());
        assert!(builder.text_line_orientation_model.is_none());
        assert!(builder.document_rectification_model.is_none());
    }

    #[test]
    fn test_oarocr_builder_with_optional_components() {
        let builder = OAROCRBuilder::new("models/det.onnx", "models/rec.onnx", "models/dict.txt")
            .with_document_image_orientation_classification("models/doc_orient.onnx")
            .with_text_line_orientation_classification("models/line_orient.onnx")
            .with_document_image_rectification("models/rectify.onnx");

        // Comparing the whole Option gives a useful failure message and avoids
        // the let-else/panic boilerplate of unwrapping each field.
        assert_eq!(
            builder.document_orientation_model,
            Some(PathBuf::from("models/doc_orient.onnx"))
        );
        assert_eq!(
            builder.text_line_orientation_model,
            Some(PathBuf::from("models/line_orient.onnx"))
        );
        assert_eq!(
            builder.document_rectification_model,
            Some(PathBuf::from("models/rectify.onnx"))
        );
    }

    #[test]
    fn test_oarocr_builder_with_configuration() {
        let det_config = TextDetectionConfig {
            score_threshold: 0.5,
            box_threshold: 0.6,
            unclip_ratio: 1.8,
            max_candidates: 1000,
            limit_side_len: None,
            limit_type: None,
            max_side_len: None,
        };

        let rec_config = TextRecognitionConfig {
            score_threshold: 0.7,
            max_text_length: 128,
        };

        // The configs are moved into the builder; cloning them first is redundant.
        let builder = OAROCRBuilder::new("models/det.onnx", "models/rec.onnx", "models/dict.txt")
            .text_detection_config(det_config)
            .text_recognition_config(rec_config);

        assert!(builder.text_detection_config.is_some());
        assert!(builder.text_recognition_config.is_some());
    }

    #[test]
    fn test_oarocr_builder_with_batch_sizes() {
        let builder = OAROCRBuilder::new("models/det.onnx", "models/rec.onnx", "models/dict.txt")
            .image_batch_size(4)
            .region_batch_size(64);

        assert_eq!(builder.image_batch_size, Some(4));
        assert_eq!(builder.region_batch_size, Some(64));
    }

    #[test]
    fn test_ctc_word_boxes_logic() {
        let line_bbox = BoundingBox::from_coords(0.0, 0.0, 100.0, 20.0);
        // seq_len=10, wh_ratio=5 (100/20), max_wh_ratio=5 -> effective_col_num = 10
        // cell_width = 100/10 = 10.0

        // Non-CJK "ABC" at indices 1, 4, 7:
        // centers = 1.5*10, 4.5*10, 7.5*10 = 15, 45, 75.
        let text = "ABC";
        let col_indices = vec![1, 4, 7];
        let seq_len = 10;
        let wh_ratio = 5.0;
        let max_wh_ratio = 5.0;

        let boxes = OAROCR::ctc_word_boxes(
            &line_bbox,
            text,
            &col_indices,
            seq_len,
            wh_ratio,
            max_wh_ratio,
        );

        // Non-CJK boxes are bounded by the midpoints between adjacent centers,
        // with the first/last boxes clamped to the line extents:
        // Box 0: [0, (15+45)/2=30], Box 1: [30, (45+75)/2=60], Box 2: [60, 100].
        let expected = [(0.0_f32, 30.0_f32), (30.0, 60.0), (60.0, 100.0)];
        assert_eq!(boxes.len(), expected.len());
        for (b, &(lo, hi)) in boxes.iter().zip(expected.iter()) {
            assert!((b.x_min() - lo).abs() < 1e-5);
            assert!((b.x_max() - hi).abs() < 1e-5);
        }
    }
}
1091}