1use crate::core::ImageReader as CoreImageReader;
14use crate::core::{
15 BatchData, CommonBuilderConfig, ConfigValidator, ConfigValidatorExt, DefaultImageReader,
16 OCRError, OrtInfer, Tensor3D, Tensor4D,
17};
18use crate::core::{
19 GranularImageReader as GIReader, ModularPredictor, OrtInfer3D, Postprocessor as GPostprocessor,
20 Preprocessor as GPreprocessor,
21};
22use crate::impl_common_builder_methods;
23use crate::impl_config_new_and_with_common;
24use crate::processors::{CTCLabelDecode, NormalizeImage, OCRResize};
25
26use image::RgbImage;
27use std::path::Path;
28use std::sync::Arc;
29
30#[derive(Debug, Clone)]
35pub struct TextRecResult {
36 pub input_path: Vec<Arc<str>>,
38 pub index: Vec<usize>,
40 pub input_img: Vec<Arc<RgbImage>>,
42 pub rec_text: Vec<Arc<str>>,
44 pub rec_score: Vec<f32>,
46}
47
48#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
52pub struct TextRecPredictorConfig {
53 pub common: CommonBuilderConfig,
55 pub model_input_shape: Option<[usize; 3]>,
59 pub character_dict: Option<Vec<String>>,
61 pub score_thresh: Option<f32>,
63}
64
65impl_config_new_and_with_common!(
66 TextRecPredictorConfig,
67 common_defaults: (Some("crnn".to_string()), Some(32)),
68 fields: {
69 model_input_shape: Some([3, 48, 320]),
70 character_dict: None,
71 score_thresh: None
72 }
73);
74
75impl ConfigValidator for TextRecPredictorConfig {
76 fn validate(&self) -> Result<(), crate::core::ConfigError> {
77 self.common.validate()?;
78
79 if let Some(shape) = self.model_input_shape
80 && (shape[0] == 0 || shape[1] == 0 || shape[2] == 0)
81 {
82 return Err(crate::core::ConfigError::InvalidConfig {
83 message: "Model input shape dimensions must be greater than 0".to_string(),
84 });
85 }
86
87 Ok(())
88 }
89
90 fn get_defaults() -> Self {
91 Self::new()
92 }
93}
94
95impl TextRecResult {
96 pub fn new() -> Self {
98 Self {
99 input_path: Vec::new(),
100 index: Vec::new(),
101 input_img: Vec::new(),
102 rec_text: Vec::new(),
103 rec_score: Vec::new(),
104 }
105 }
106}
107
108impl Default for TextRecResult {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114pub type TextRecPredictor =
119 ModularPredictor<TRImageReader, TRPreprocessor, OrtInfer3D, TRPostprocessor>;
120
121#[derive(Debug)]
122pub struct TRImageReader {
123 inner: DefaultImageReader,
124}
125impl TRImageReader {
126 pub fn new() -> Self {
127 Self {
128 inner: DefaultImageReader::new(),
129 }
130 }
131}
132impl Default for TRImageReader {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137impl GIReader for TRImageReader {
138 fn read_images<'a>(
139 &self,
140 paths: impl Iterator<Item = &'a str>,
141 ) -> Result<Vec<RgbImage>, OCRError> {
142 self.inner.apply(paths)
143 }
144}
145
146#[derive(Debug)]
147pub struct TRPreprocessor {
148 pub resize: OCRResize,
149 pub normalize: NormalizeImage,
150}
151impl GPreprocessor for TRPreprocessor {
152 type Config = TextRecConfig;
153 type Output = Tensor4D;
154 fn preprocess(
155 &self,
156 images: Vec<RgbImage>,
157 _config: Option<&Self::Config>,
158 ) -> Result<Self::Output, OCRError> {
159 let resized_imgs = self.resize.apply_to_images(&images)?;
160 let dynamic_imgs: Vec<image::DynamicImage> = resized_imgs
161 .into_iter()
162 .map(image::DynamicImage::ImageRgb8)
163 .collect();
164 self.normalize.normalize_batch_to(dynamic_imgs)
165 }
166}
167
168#[derive(Debug)]
169pub struct TRPostprocessor {
170 pub decoder: CTCLabelDecode,
171}
172impl GPostprocessor for TRPostprocessor {
173 type Config = TextRecConfig;
174 type InferenceOutput = Tensor3D;
175 type PreprocessOutput = Tensor4D;
176 type Result = TextRecResult;
177 fn postprocess(
178 &self,
179 output: Self::InferenceOutput,
180 _pre: Option<&Self::PreprocessOutput>,
181 batch_data: &BatchData,
182 raw_images: Vec<RgbImage>,
183 _config: Option<&Self::Config>,
184 ) -> crate::core::OcrResult<Self::Result> {
185 let (texts, scores) = self.decoder.apply(&output);
186 Ok(TextRecResult {
187 input_path: batch_data.input_paths.clone(),
188 index: batch_data.indexes.clone(),
189 input_img: raw_images.into_iter().map(Arc::new).collect(),
190 rec_text: texts.into_iter().map(Arc::from).collect(),
191 rec_score: scores,
192 })
193 }
194 fn empty_result(&self) -> crate::core::OcrResult<Self::Result> {
195 Ok(TextRecResult::new())
196 }
197}
198
/// Per-call configuration type for the text recognition stages.
///
/// A unit struct carrying no data; it serves as the `Config` associated type
/// of [`TRPreprocessor`] and [`TRPostprocessor`], both of which ignore it.
#[derive(Debug, Clone)]
pub struct TextRecConfig;
204
205pub struct TextRecPredictorBuilder {
209 common: CommonBuilderConfig,
211
212 model_input_shape: Option<[usize; 3]>,
214 character_dict: Option<Vec<String>>,
216 score_thresh: Option<f32>,
218}
219
220impl_common_builder_methods!(TextRecPredictorBuilder, common);
221
222impl TextRecPredictorBuilder {
223 pub fn new() -> Self {
227 Self {
228 common: CommonBuilderConfig::new(),
229 model_input_shape: None,
230 character_dict: None,
231 score_thresh: None,
232 }
233 }
234
235 pub fn model_input_shape(mut self, shape: [usize; 3]) -> Self {
240 self.model_input_shape = Some(shape);
241 self
242 }
243
244 pub fn character_dict(mut self, character_dict: Vec<String>) -> Self {
248 self.character_dict = Some(character_dict);
249 self
250 }
251
252 pub fn score_thresh(mut self, score_thresh: f32) -> Self {
257 self.score_thresh = Some(score_thresh);
258 self
259 }
260
261 pub fn build(self, model_path: &Path) -> crate::core::OcrResult<TextRecPredictor> {
265 self.build_internal(model_path)
266 }
267
268 fn build_internal(mut self, model_path: &Path) -> crate::core::OcrResult<TextRecPredictor> {
273 if self.common.model_path.is_none() {
275 self.common = self.common.model_path(model_path.to_path_buf());
276 }
277
278 let config = TextRecPredictorConfig {
280 common: self.common,
281 model_input_shape: self.model_input_shape,
282 character_dict: self.character_dict,
283 score_thresh: self.score_thresh,
284 };
285
286 let config = config.validate_and_wrap_ocr_error()?;
288
289 let model_input_shape = config.model_input_shape.unwrap_or([3, 48, 320]);
291 let character_dict = config.character_dict.clone();
292
293 let image_reader = TRImageReader::new();
294 let resize = OCRResize::new(Some(model_input_shape), None);
295 let normalize = NormalizeImage::for_ocr_recognition()?;
296 let preprocessor = TRPreprocessor { resize, normalize };
297 let infer = OrtInfer::from_common(&config.common, model_path, None)?;
298 let inference_engine = OrtInfer3D::new(infer);
299 let decoder = CTCLabelDecode::from_string_list(character_dict.as_deref(), true, false);
300 let postprocessor = TRPostprocessor { decoder };
301
302 Ok(ModularPredictor::new(
303 image_reader,
304 preprocessor,
305 inference_engine,
306 postprocessor,
307 ))
308 }
309}
310
311impl Default for TextRecPredictorBuilder {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    // Already compiled only for tests via the enclosing `#[cfg(test)]` module.
    mod tests_local {
        use super::*;

        /// A freshly constructed config carries the expected defaults and validates.
        #[test]
        fn test_text_rec_config_defaults_and_validate() {
            let defaults = TextRecPredictorConfig::new();

            assert_eq!(defaults.model_input_shape, Some([3, 48, 320]));
            assert_eq!(defaults.common.model_name.as_deref(), Some("crnn"));
            assert_eq!(defaults.common.batch_size, Some(32));
            assert!(defaults.validate().is_ok());
        }
    }

    /// `score_thresh` starts out unset and keeps whatever value is assigned.
    #[test]
    fn test_text_rec_predictor_config_score_thresh() {
        let untouched = TextRecPredictorConfig::new();
        assert_eq!(untouched.score_thresh, None);

        let overridden = TextRecPredictorConfig {
            score_thresh: Some(0.5),
            ..TextRecPredictorConfig::new()
        };
        assert_eq!(overridden.score_thresh, Some(0.5));
    }

    /// The builder setter stores the threshold it is given.
    #[test]
    fn test_text_rec_predictor_builder_score_thresh() {
        let configured = TextRecPredictorBuilder::new().score_thresh(0.7);

        assert_eq!(configured.score_thresh, Some(0.7));
    }
}