1use crate::common_builder_methods;
10use crate::core::ImageReader as CoreImageReader;
11use crate::core::{
12 BatchData, CommonBuilderConfig, DefaultImageReader, OCRError, OrtInfer, Tensor2D, Tensor4D,
13 config::{ConfigValidator, ConfigValidatorExt},
14 get_text_line_orientation_labels,
15};
16use crate::core::{
17 GranularImageReader as GIReader, ModularPredictor, OrtInfer2D, Postprocessor as GPostprocessor,
18 Preprocessor as GPreprocessor,
19};
20
21use crate::processors::{Crop, NormalizeImage, Topk};
22use image::{DynamicImage, RgbImage};
23use std::path::Path;
24use std::sync::Arc;
25
26use crate::impl_config_new_and_with_common;
27
28#[derive(Debug, Clone)]
33pub struct TextLineClasResult {
34 pub input_path: Vec<Arc<str>>,
36 pub index: Vec<usize>,
38 pub input_img: Vec<Arc<RgbImage>>,
40 pub class_ids: Vec<Vec<usize>>,
42 pub scores: Vec<Vec<f32>>,
44 pub label_names: Vec<Vec<Arc<str>>>,
46}
47
48#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
53pub struct TextLineClasPredictorConfig {
54 pub common: CommonBuilderConfig,
56 pub topk: Option<usize>,
58 pub input_shape: Option<(u32, u32)>,
60}
61
62impl_config_new_and_with_common!(
63 TextLineClasPredictorConfig,
64 common_defaults: (Some("PP-LCNet_x0_25".to_string()), Some(1)),
65 fields: {
66 topk: None,
67 input_shape: Some((224, 224))
68 }
69);
70
71impl TextLineClasPredictorConfig {
72 pub fn validate(&self) -> Result<(), crate::core::ConfigError> {
80 ConfigValidator::validate(self)
81 }
82}
83
84impl ConfigValidator for TextLineClasPredictorConfig {
85 fn validate(&self) -> Result<(), crate::core::ConfigError> {
94 self.common.validate()?;
95
96 if let Some(topk) = self.topk {
97 self.validate_positive_usize(topk, "topk")?;
98 }
99
100 if let Some((width, height)) = self.input_shape {
101 self.validate_image_dimensions(width, height)?;
102 }
103
104 Ok(())
105 }
106
107 fn get_defaults() -> Self {
116 Self {
117 common: CommonBuilderConfig::get_defaults(),
118 topk: Some(2),
119 input_shape: Some((224, 224)),
120 }
121 }
122}
123
124impl TextLineClasResult {
125 pub fn new() -> Self {
134 Self {
135 input_path: Vec::new(),
136 index: Vec::new(),
137 input_img: Vec::new(),
138 class_ids: Vec::new(),
139 scores: Vec::new(),
140 label_names: Vec::new(),
141 }
142 }
143}
144
145impl Default for TextLineClasResult {
146 fn default() -> Self {
154 Self::new()
155 }
156}
157
158pub type TextLineClasPredictor =
163 ModularPredictor<TLImageReader, TLPreprocessor, OrtInfer2D, TLPostprocessor>;
164
/// Field-less marker type used as the `Config` associated type of the
/// preprocessor and postprocessor in this file. Both ignore the value they
/// receive.
#[derive(Clone, Debug)]
pub struct TextLineClasConfig;
167
168#[derive(Debug)]
169pub struct TLImageReader {
170 inner: DefaultImageReader,
171}
172impl TLImageReader {
173 pub fn new() -> Self {
181 Self {
182 inner: DefaultImageReader::new(),
183 }
184 }
185}
186impl Default for TLImageReader {
187 fn default() -> Self {
191 Self::new()
192 }
193}
194impl GIReader for TLImageReader {
195 fn read_images<'a>(
196 &self,
197 paths: impl Iterator<Item = &'a str>,
198 ) -> Result<Vec<RgbImage>, OCRError> {
199 self.inner.apply(paths)
200 }
201}
202
203#[derive(Debug)]
204pub struct TLPreprocessor {
205 pub input_shape: (u32, u32),
206 pub crop: Option<Crop>,
207 pub normalize: NormalizeImage,
208}
209impl GPreprocessor for TLPreprocessor {
210 type Config = TextLineClasConfig;
211 type Output = Tensor4D;
212 fn preprocess(
213 &self,
214 images: Vec<RgbImage>,
215 _config: Option<&Self::Config>,
216 ) -> Result<Self::Output, OCRError> {
217 use crate::utils::resize_images_batch;
218 let (width, height) = self.input_shape;
219 let mut batch_imgs = resize_images_batch(&images, width, height, None);
220 if let Some(crop_op) = &self.crop {
221 batch_imgs = crop_op.process_batch(&batch_imgs).map_err(|e| {
222 OCRError::post_processing("Crop operation failed during text classification", e)
223 })?;
224 }
225 let imgs_dynamic: Vec<DynamicImage> = batch_imgs
226 .iter()
227 .map(|img| DynamicImage::ImageRgb8(img.clone()))
228 .collect();
229 self.normalize.normalize_batch_to(imgs_dynamic)
230 }
231 fn preprocessing_info(&self) -> String {
232 format!(
233 "resize_to=({},{}) + crop? + normalize",
234 self.input_shape.0, self.input_shape.1
235 )
236 }
237}
238
239#[derive(Debug)]
240pub struct TLPostprocessor {
241 pub topk: usize,
242 pub topk_op: Topk,
243}
244impl GPostprocessor for TLPostprocessor {
245 type Config = TextLineClasConfig;
246 type InferenceOutput = Tensor2D;
247 type PreprocessOutput = Tensor4D;
248 type Result = TextLineClasResult;
249 fn postprocess(
250 &self,
251 output: Self::InferenceOutput,
252 _pre: Option<&Self::PreprocessOutput>,
253 batch_data: &BatchData,
254 raw_images: Vec<RgbImage>,
255 _config: Option<&Self::Config>,
256 ) -> crate::core::OcrResult<Self::Result> {
257 let predictions: Vec<Vec<f32>> = output.outer_iter().map(|row| row.to_vec()).collect();
258 let topk_result = self
259 .topk_op
260 .process(&predictions, self.topk)
261 .map_err(|e| OCRError::ConfigError { message: e })?;
262 Ok(TextLineClasResult {
263 input_path: batch_data.input_paths.clone(),
264 index: batch_data.indexes.clone(),
265 input_img: raw_images.into_iter().map(Arc::new).collect(),
266 class_ids: topk_result.indexes,
267 scores: topk_result.scores,
268 label_names: topk_result
269 .class_names
270 .unwrap_or_default()
271 .into_iter()
272 .map(|names| names.into_iter().map(Arc::from).collect())
273 .collect(),
274 })
275 }
276 fn empty_result(&self) -> crate::core::OcrResult<Self::Result> {
277 Ok(TextLineClasResult::new())
278 }
279}
280
281pub struct TextLineClasPredictorBuilder {
286 common: CommonBuilderConfig,
288
289 topk: Option<usize>,
291 input_shape: Option<(u32, u32)>,
293}
294
295impl TextLineClasPredictorBuilder {
296 pub fn new() -> Self {
305 Self {
306 common: CommonBuilderConfig::new(),
307 topk: None,
308 input_shape: None,
309 }
310 }
311
312 common_builder_methods!(common);
314
315 pub fn topk(mut self, topk: usize) -> Self {
327 self.topk = Some(topk);
328 self
329 }
330
331 pub fn input_shape(mut self, input_shape: (u32, u32)) -> Self {
343 self.input_shape = Some(input_shape);
344 self
345 }
346
347 pub fn build(self, model_path: &Path) -> crate::core::OcrResult<TextLineClasPredictor> {
360 self.build_internal(model_path)
361 }
362
363 fn build_internal(
377 mut self,
378 model_path: &Path,
379 ) -> crate::core::OcrResult<TextLineClasPredictor> {
380 if self.common.model_path.is_none() {
381 self.common = self.common.model_path(model_path.to_path_buf());
382 }
383
384 let config = TextLineClasPredictorConfig {
385 common: self.common,
386 topk: self.topk,
387 input_shape: self.input_shape,
388 };
389
390 let config = config.validate_and_wrap_ocr_error()?;
391
392 let input_shape = config.input_shape.unwrap_or((224, 224));
393 let (width, height) = input_shape;
394 let crop = Some(
395 Crop::new([width, height], crate::processors::CropMode::Center).map_err(|e| {
396 OCRError::ConfigError {
397 message: format!("Failed to create crop operation: {e}"),
398 }
399 })?,
400 );
401 let normalize = NormalizeImage::new(
402 Some(1.0 / 255.0),
403 Some(vec![0.485, 0.456, 0.406]),
404 Some(vec![0.229, 0.224, 0.225]),
405 None,
406 )?;
407 let preprocessor = TLPreprocessor {
408 input_shape,
409 crop,
410 normalize,
411 };
412 let infer_inner = OrtInfer::from_common(&config.common, model_path, None)?;
413 let inference_engine = OrtInfer2D::new(infer_inner);
414 let postprocessor = TLPostprocessor {
415 topk: config.topk.unwrap_or(2),
416 topk_op: Topk::from_class_names(get_text_line_orientation_labels()),
417 };
418 let image_reader = TLImageReader::new();
419 Ok(ModularPredictor::new(
420 image_reader,
421 preprocessor,
422 inference_engine,
423 postprocessor,
424 ))
425 }
426}
427
428impl Default for TextLineClasPredictorBuilder {
429 fn default() -> Self {
434 Self::new()
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must expose the macro-provided defaults and pass validation.
    #[test]
    fn test_text_line_clas_config_defaults_and_validate() {
        let cfg = TextLineClasPredictorConfig::new();

        // Classifier-specific fields.
        assert!(cfg.topk.is_none());
        assert_eq!(cfg.input_shape, Some((224, 224)));

        // Common defaults supplied through the macro.
        assert_eq!(cfg.common.model_name.as_deref(), Some("PP-LCNet_x0_25"));
        assert_eq!(cfg.common.batch_size, Some(1));

        assert!(cfg.validate().is_ok());
    }
}