1use image::RgbImage;
47use serde::{Deserialize, Serialize};
48use std::sync::Arc;
49
50use crate::metrics;
51use tracing::{debug, warn};
52
53use super::extensible::{PipelineStage, StageContext, StageData, StageDependency, StageId};
54
55use super::types::StageResult;
56use crate::core::config::ConfigValidator;
57use crate::core::{OCRError, StandardPredictor};
58use crate::predictor::TextRecPredictor;
59
60#[path = "recognition_grouping.rs"]
61mod recognition_grouping;
62#[path = "recognition_orientation.rs"]
63mod recognition_orientation;
64
65pub use recognition_grouping::{GroupingStrategy, GroupingStrategyConfig, GroupingStrategyFactory};
66pub use recognition_orientation::{OrientationCorrectionConfig, OrientationCorrector};
67
/// Output of the recognition stage.
///
/// `rec_texts` and `rec_scores` are index-aligned with the cropped images fed
/// into the stage: entry `i` describes crop `i`.
#[derive(Clone, Debug)]
pub struct RecognitionResult {
    /// Recognized text per crop; an empty string when the crop was not recognized.
    pub rec_texts: Vec<Arc<str>>,
    /// Recognition score per crop; `0.0` when the crop was not recognized.
    pub rec_scores: Vec<f32>,
    /// Number of crops whose recognition did not produce a result.
    pub failed_recognitions: usize,
}
78
79#[derive(Debug, Clone, Serialize, Deserialize, Default)]
81pub struct RecognitionConfig {
82 #[serde(default)]
84 pub orientation_correction: OrientationCorrectionConfig,
85 #[serde(default)]
87 pub grouping_strategy: GroupingStrategyConfig,
88}
89
90impl RecognitionConfig {
91 pub fn from_legacy_config(
93 use_textline_orientation: bool,
94 aspect_ratio_bucketing: Option<crate::processors::AspectRatioBucketingConfig>,
95 ) -> Self {
96 let orientation_correction = OrientationCorrectionConfig {
97 enabled: use_textline_orientation,
98 };
99
100 let grouping_strategy = if let Some(bucketing_config) = aspect_ratio_bucketing {
101 GroupingStrategyConfig::AspectRatioBucketing(bucketing_config)
102 } else {
103 GroupingStrategyConfig::ExactDimensions
104 };
105
106 Self {
107 orientation_correction,
108 grouping_strategy,
109 }
110 }
111}
112
113impl ConfigValidator for RecognitionConfig {
114 fn validate(&self) -> Result<(), crate::core::config::ConfigError> {
115 Ok(())
118 }
119
120 fn get_defaults() -> Self {
121 Self::default()
122 }
123}
124
125pub struct RecognitionStageProcessor;
133
134impl RecognitionStageProcessor {
135 pub fn process_single(
148 cropped_images: Vec<Option<RgbImage>>,
149 text_line_orientations: Option<&[Option<f32>]>,
150 recognizer: Option<&TextRecPredictor>,
151 config: Option<&RecognitionConfig>,
152 ) -> Result<StageResult<RecognitionResult>, OCRError> {
153 use std::time::Instant;
154 let start_time = Instant::now();
155 let default_config = RecognitionConfig::default();
156 let config = config.unwrap_or(&default_config);
157
158 let mut failed_recognitions = 0;
159 let (rec_texts, rec_scores) = if cropped_images.is_empty() {
160 (Vec::new(), Vec::new())
161 } else if let Some(recognizer) = recognizer {
162 Self::process_recognition_groups(
163 &cropped_images,
164 text_line_orientations,
165 recognizer,
166 config,
167 &mut failed_recognitions,
168 )?
169 } else {
170 debug!("No text recognizer available, skipping recognition");
171 let empty_texts = vec![Arc::from(""); cropped_images.len()];
172 let empty_scores = vec![0.0; cropped_images.len()];
173 (empty_texts, empty_scores)
174 };
175
176 let success_count = cropped_images.len() - failed_recognitions;
177 let grouping_strategy_name = match config.grouping_strategy {
178 GroupingStrategyConfig::AspectRatioBucketing(_) => "aspect_ratio_bucketing",
179 GroupingStrategyConfig::ExactDimensions => "exact_dimensions",
180 };
181 let metrics = metrics!(success_count, failed_recognitions, start_time;
182 stage = "recognition",
183 text_regions = cropped_images.len(),
184 grouping_strategy = grouping_strategy_name
185 );
186
187 let result = RecognitionResult {
188 rec_texts,
189 rec_scores,
190 failed_recognitions,
191 };
192
193 Ok(StageResult::new(result, metrics))
194 }
195
196 fn process_recognition_groups(
198 cropped_images: &[Option<RgbImage>],
199 text_line_orientations: Option<&[Option<f32>]>,
200 recognizer: &TextRecPredictor,
201 config: &RecognitionConfig,
202 failed_recognitions: &mut usize,
203 ) -> Result<(Vec<Arc<str>>, Vec<f32>), OCRError> {
204 let images_for_grouping: Vec<(usize, RgbImage)> = cropped_images
206 .iter()
207 .enumerate()
208 .filter_map(|(i, cropped_img_opt)| cropped_img_opt.as_ref().map(|img| (i, img.clone())))
209 .collect();
210
211 let grouping_strategy =
213 GroupingStrategyFactory::create_strategy(&config.grouping_strategy)?;
214 let groups = grouping_strategy.group_images(images_for_grouping)?;
215
216 let orientation_corrector =
218 OrientationCorrector::new(config.orientation_correction.clone());
219
220 let mut recognition_results: Vec<(usize, Arc<str>, f32)> = Vec::new();
222 for (group_name, group) in groups {
223 Self::process_recognition_group(
224 group_name,
225 group,
226 text_line_orientations,
227 recognizer,
228 &orientation_corrector,
229 &mut recognition_results,
230 failed_recognitions,
231 )?;
232 }
233
234 recognition_results.sort_by_key(|(idx, _, _)| *idx);
236 let mut rec_texts = vec![Arc::from(""); cropped_images.len()];
237 let mut rec_scores = vec![0.0; cropped_images.len()];
238
239 for (original_idx, text, score) in recognition_results {
240 if original_idx < rec_texts.len() {
241 rec_texts[original_idx] = text;
242 rec_scores[original_idx] = score;
243 }
244 }
245
246 Ok((rec_texts, rec_scores))
247 }
248
249 fn process_recognition_group(
251 group_name: String,
252 group: Vec<(usize, RgbImage)>,
253 text_line_orientations: Option<&[Option<f32>]>,
254 recognizer: &TextRecPredictor,
255 orientation_corrector: &OrientationCorrector,
256 recognition_results: &mut Vec<(usize, Arc<str>, f32)>,
257 failed_recognitions: &mut usize,
258 ) -> Result<(), OCRError> {
259 let (indices, mut images): (Vec<usize>, Vec<RgbImage>) = group.into_iter().unzip();
260
261 let corrections_applied =
263 orientation_corrector.apply_corrections(&mut images, &indices, text_line_orientations);
264
265 if corrections_applied > 0 {
266 debug!(
267 "Applied {} orientation corrections in group '{}'",
268 corrections_applied, group_name
269 );
270 }
271
272 debug!(
273 "Processing recognition group '{}' with {} images",
274 group_name,
275 images.len()
276 );
277
278 match recognizer.predict(images, None) {
279 Ok(result) => {
280 for (original_idx, (text, score)) in indices
281 .into_iter()
282 .zip(result.rec_text.iter().zip(result.rec_score.iter()))
283 {
284 recognition_results.push((original_idx, text.clone(), *score));
285 }
286 }
287 Err(e) => {
288 *failed_recognitions += indices.len();
289
290 let enhanced_error = OCRError::format_batch_error_message(
292 "text recognition",
293 &group_name,
294 &indices,
295 &e,
296 );
297
298 warn!(
299 "Text recognition failed for batch '{}' of {} images (indices: {:?}): {}",
300 group_name,
301 indices.len(),
302 indices,
303 enhanced_error
304 );
305
306 for original_idx in indices {
307 recognition_results.push((original_idx, Arc::from(""), 0.0));
308 }
309 }
310 }
311
312 Ok(())
313 }
314}
315
316#[derive(Debug)]
318pub struct ExtensibleRecognitionStage {
319 recognizer: Option<Arc<TextRecPredictor>>,
320}
321
322impl ExtensibleRecognitionStage {
323 pub fn new(recognizer: Option<Arc<TextRecPredictor>>) -> Self {
325 Self { recognizer }
326 }
327}
328
329impl PipelineStage for ExtensibleRecognitionStage {
330 type Config = RecognitionConfig;
331 type Result = RecognitionResult;
332
333 fn stage_id(&self) -> StageId {
334 StageId::new("recognition")
335 }
336
337 fn stage_name(&self) -> &str {
338 "Text Recognition"
339 }
340
341 fn dependencies(&self) -> Vec<StageDependency> {
342 vec![StageDependency::Requires(StageId::new("cropping"))]
344 }
345
346 fn is_enabled(&self, _context: &StageContext, _config: Option<&Self::Config>) -> bool {
347 self.recognizer.is_some()
348 }
349
350 fn process(
351 &self,
352 context: &mut StageContext,
353 _data: StageData,
354 config: Option<&Self::Config>,
355 ) -> Result<StageResult<Self::Result>, OCRError> {
356 let cropping_result = context
358 .get_stage_result::<super::cropping::CroppingResult>(&StageId::new("cropping"))
359 .ok_or_else(|| {
360 OCRError::processing_error(
361 crate::core::ProcessingStage::Generic,
362 "Cropping results not found in context",
363 crate::core::errors::SimpleError::new("Missing cropping results"),
364 )
365 })?;
366
367 let text_line_orientations =
369 context.get_stage_result::<Vec<Option<f32>>>(&StageId::new("text_line_orientation"));
370
371 let recognition_config = config.cloned().unwrap_or_default();
372
373 let stage_result = RecognitionStageProcessor::process_single(
374 cropping_result.cropped_images.clone(),
375 text_line_orientations.map(|v| &**v),
376 self.recognizer.as_ref().map(|r| r.as_ref()),
377 Some(&recognition_config),
378 )?;
379
380 Ok(stage_result)
381 }
382
383 fn validate_config(&self, config: &Self::Config) -> Result<(), OCRError> {
384 config.validate().map_err(|e| OCRError::ConfigError {
385 message: format!("RecognitionConfig validation failed: {}", e),
386 })
387 }
388
389 fn default_config(&self) -> Self::Config {
390 RecognitionConfig::get_defaults()
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::AspectRatioBucketingConfig;

    #[test]
    fn test_recognition_config_from_legacy_config() {
        // With bucketing settings supplied, the bucketing strategy is selected.
        let bucketed = RecognitionConfig::from_legacy_config(
            true,
            Some(AspectRatioBucketingConfig::default()),
        );
        assert!(bucketed.orientation_correction.enabled);
        assert!(
            matches!(
                bucketed.grouping_strategy,
                GroupingStrategyConfig::AspectRatioBucketing(_)
            ),
            "Should be AspectRatioBucketing"
        );

        // Without them, grouping falls back to exact dimensions.
        let exact = RecognitionConfig::from_legacy_config(false, None);
        assert!(!exact.orientation_correction.enabled);
        assert!(
            matches!(
                exact.grouping_strategy,
                GroupingStrategyConfig::ExactDimensions
            ),
            "Should be ExactDimensions"
        );
    }

    #[test]
    fn test_recognition_config_serialization() {
        let original = RecognitionConfig::from_legacy_config(
            true,
            Some(AspectRatioBucketingConfig::default()),
        );

        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: RecognitionConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(
            original.orientation_correction.enabled,
            round_tripped.orientation_correction.enabled
        );
        assert!(
            matches!(
                (&original.grouping_strategy, &round_tripped.grouping_strategy),
                (
                    GroupingStrategyConfig::AspectRatioBucketing(_),
                    GroupingStrategyConfig::AspectRatioBucketing(_),
                )
            ),
            "Grouping strategy should match after serialization"
        );
    }
}