oar_ocr_core/predictors/
text_detection.rs1use super::builder::PredictorBuilderState;
6use crate::TaskPredictorBuilder;
7use crate::core::OcrResult;
8use crate::core::traits::OrtConfigurable;
9use crate::core::traits::adapter::AdapterBuilder;
10use crate::core::traits::task::ImageTaskInput;
11use crate::domain::adapters::TextDetectionAdapterBuilder;
12use crate::domain::tasks::text_detection::{TextDetectionConfig, TextDetectionTask};
13use crate::predictors::TaskPredictorCore;
14use image::RgbImage;
15use std::path::Path;
16
17#[derive(Debug, Clone)]
19pub struct TextDetectionResult {
20 pub detections: Vec<Vec<crate::domain::tasks::text_detection::Detection>>,
22}
23
24pub struct TextDetectionPredictor {
26 core: TaskPredictorCore<TextDetectionTask>,
27}
28
29impl TextDetectionPredictor {
30 pub fn builder() -> TextDetectionPredictorBuilder {
32 TextDetectionPredictorBuilder::new()
33 }
34
35 pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<TextDetectionResult> {
37 let input = ImageTaskInput::new(images);
39
40 let output = self.core.predict(input)?;
42
43 Ok(TextDetectionResult {
44 detections: output.detections,
45 })
46 }
47}
48
49#[derive(TaskPredictorBuilder)]
51#[builder(config = TextDetectionConfig)]
52pub struct TextDetectionPredictorBuilder {
53 state: PredictorBuilderState<TextDetectionConfig>,
54}
55
56impl TextDetectionPredictorBuilder {
57 pub fn new() -> Self {
59 Self {
60 state: PredictorBuilderState::new(TextDetectionConfig {
61 score_threshold: 0.3,
62 box_threshold: 0.6,
63 unclip_ratio: 1.5,
64 max_candidates: 1000,
65 limit_side_len: None,
66 limit_type: None,
67 max_side_len: None,
68 }),
69 }
70 }
71
72 pub fn score_threshold(mut self, threshold: f32) -> Self {
74 self.state.config_mut().score_threshold = threshold;
75 self
76 }
77
78 pub fn box_threshold(mut self, threshold: f32) -> Self {
80 self.state.config_mut().box_threshold = threshold;
81 self
82 }
83
84 pub fn unclip_ratio(mut self, ratio: f32) -> Self {
86 self.state.config_mut().unclip_ratio = ratio;
87 self
88 }
89
90 pub fn max_candidates(mut self, max: usize) -> Self {
92 self.state.config_mut().max_candidates = max;
93 self
94 }
95
96 pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<TextDetectionPredictor> {
98 let (config, ort_config) = self.state.into_parts();
99 let mut adapter_builder = TextDetectionAdapterBuilder::new().with_config(config.clone());
100
101 if let Some(ort_cfg) = ort_config {
102 adapter_builder = adapter_builder.with_ort_config(ort_cfg);
103 }
104
105 let adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
106 let task = TextDetectionTask::new(config.clone());
107
108 Ok(TextDetectionPredictor {
109 core: TaskPredictorCore::new(adapter, task, config),
110 })
111 }
112}
113
114impl Default for TextDetectionPredictorBuilder {
115 fn default() -> Self {
116 Self::new()
117 }
118}