oar_ocr_core/predictors/
text_detection.rs

1//! Text Detection Predictor
2//!
3//! This module provides a high-level API for text detection in images.
4
5use super::builder::PredictorBuilderState;
6use crate::TaskPredictorBuilder;
7use crate::core::OcrResult;
8use crate::core::traits::OrtConfigurable;
9use crate::core::traits::adapter::AdapterBuilder;
10use crate::core::traits::task::ImageTaskInput;
11use crate::domain::adapters::TextDetectionAdapterBuilder;
12use crate::domain::tasks::text_detection::{TextDetectionConfig, TextDetectionTask};
13use crate::predictors::TaskPredictorCore;
14use image::RgbImage;
15use std::path::Path;
16
17/// Text detection prediction result
18#[derive(Debug, Clone)]
19pub struct TextDetectionResult {
20    /// Detected text regions for each input image
21    pub detections: Vec<Vec<crate::domain::tasks::text_detection::Detection>>,
22}
23
24/// Text detection predictor
25pub struct TextDetectionPredictor {
26    core: TaskPredictorCore<TextDetectionTask>,
27}
28
29impl TextDetectionPredictor {
30    /// Create a new builder for the text detection predictor
31    pub fn builder() -> TextDetectionPredictorBuilder {
32        TextDetectionPredictorBuilder::new()
33    }
34
35    /// Predict text regions in the given images.
36    pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<TextDetectionResult> {
37        // Create task input
38        let input = ImageTaskInput::new(images);
39
40        // Use core predictor for validation and execution
41        let output = self.core.predict(input)?;
42
43        Ok(TextDetectionResult {
44            detections: output.detections,
45        })
46    }
47}
48
49/// Builder for text detection predictor
50#[derive(TaskPredictorBuilder)]
51#[builder(config = TextDetectionConfig)]
52pub struct TextDetectionPredictorBuilder {
53    state: PredictorBuilderState<TextDetectionConfig>,
54}
55
56impl TextDetectionPredictorBuilder {
57    /// Create a new builder with default configuration
58    pub fn new() -> Self {
59        Self {
60            state: PredictorBuilderState::new(TextDetectionConfig {
61                score_threshold: 0.3,
62                box_threshold: 0.6,
63                unclip_ratio: 1.5,
64                max_candidates: 1000,
65                limit_side_len: None,
66                limit_type: None,
67                max_side_len: None,
68            }),
69        }
70    }
71
72    /// Set the score threshold
73    pub fn score_threshold(mut self, threshold: f32) -> Self {
74        self.state.config_mut().score_threshold = threshold;
75        self
76    }
77
78    /// Set the box threshold
79    pub fn box_threshold(mut self, threshold: f32) -> Self {
80        self.state.config_mut().box_threshold = threshold;
81        self
82    }
83
84    /// Set the unclip ratio
85    pub fn unclip_ratio(mut self, ratio: f32) -> Self {
86        self.state.config_mut().unclip_ratio = ratio;
87        self
88    }
89
90    /// Set the maximum candidates
91    pub fn max_candidates(mut self, max: usize) -> Self {
92        self.state.config_mut().max_candidates = max;
93        self
94    }
95
96    /// Build the text detection predictor
97    pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<TextDetectionPredictor> {
98        let (config, ort_config) = self.state.into_parts();
99        let mut adapter_builder = TextDetectionAdapterBuilder::new().with_config(config.clone());
100
101        if let Some(ort_cfg) = ort_config {
102            adapter_builder = adapter_builder.with_ort_config(ort_cfg);
103        }
104
105        let adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
106        let task = TextDetectionTask::new(config.clone());
107
108        Ok(TextDetectionPredictor {
109            core: TaskPredictorCore::new(adapter, task, config),
110        })
111    }
112}
113
114impl Default for TextDetectionPredictorBuilder {
115    fn default() -> Self {
116        Self::new()
117    }
118}