oar_ocr/pipeline/stages/
recognition.rs

1//! Text recognition processing stage processor.
2//!
3//! This module provides a refactored text recognition pipeline that separates concerns
4//! into dedicated components:
5//!
6//! ## Architecture
7//!
8//! The recognition stage has been refactored to separate mixed responsibilities:
9//!
10//! ### Grouping Strategies
11//! - **`GroupingStrategy`** trait: Defines how images are grouped for batch processing
12//! - **`ExactDimensionStrategy`**: Groups images by exact pixel dimensions
13//! - **`AspectRatioBucketingStrategy`**: Groups images by aspect ratio ranges for better batching efficiency
14//!
15//! ### Orientation Correction
16//! - **`OrientationCorrector`**: Handles text line orientation corrections separately from grouping logic
17//! - Configurable through `OrientationCorrectionConfig`
18//!
19//! ### Configuration
20//! - **`RecognitionConfig`**: New unified configuration using the separated components
21//! - **`from_legacy_config()`**: Helper method for backward compatibility
22//!
23//! ## Example Usage
24//!
25//! ```rust
26//! use oar_ocr::pipeline::stages::{
27//!     RecognitionConfig, GroupingStrategyConfig, OrientationCorrectionConfig
28//! };
29//! use oar_ocr::processors::AspectRatioBucketingConfig;
30//!
31//! // Create config with aspect ratio bucketing and orientation correction
32//! let config = RecognitionConfig {
33//!     orientation_correction: OrientationCorrectionConfig { enabled: true },
34//!     grouping_strategy: GroupingStrategyConfig::AspectRatioBucketing(
35//!         AspectRatioBucketingConfig::default()
36//!     ),
37//! };
38//!
39//! // Or create from legacy config format
40//! let legacy_config = RecognitionConfig::from_legacy_config(
41//!     true, // use_textline_orientation
42//!     Some(AspectRatioBucketingConfig::default()), // aspect_ratio_bucketing
43//! );
44//! ```
45
46use image::RgbImage;
47use serde::{Deserialize, Serialize};
48use std::sync::Arc;
49
50use crate::metrics;
51use tracing::{debug, warn};
52
53use super::extensible::{PipelineStage, StageContext, StageData, StageDependency, StageId};
54
55use super::types::StageResult;
56use crate::core::config::ConfigValidator;
57use crate::core::{OCRError, StandardPredictor};
58use crate::predictor::TextRecPredictor;
59
60#[path = "recognition_grouping.rs"]
61mod recognition_grouping;
62#[path = "recognition_orientation.rs"]
63mod recognition_orientation;
64
65pub use recognition_grouping::{GroupingStrategy, GroupingStrategyConfig, GroupingStrategyFactory};
66pub use recognition_orientation::{OrientationCorrectionConfig, OrientationCorrector};
67
/// Outcome of running text recognition over a batch of cropped text regions.
///
/// `rec_texts` and `rec_scores` are parallel vectors: entry `i` of each
/// belongs to the `i`-th input region.
#[derive(Clone, Debug)]
pub struct RecognitionResult {
    /// Recognized text per region, in the same order as the input regions.
    pub rec_texts: Vec<Arc<str>>,
    /// Recognition score per region, in the same order as the input regions.
    pub rec_scores: Vec<f32>,
    /// How many regions were counted as failed during recognition.
    pub failed_recognitions: usize,
}
78
79/// Configuration for recognition processing
80#[derive(Debug, Clone, Serialize, Deserialize, Default)]
81pub struct RecognitionConfig {
82    /// Configuration for orientation correction
83    #[serde(default)]
84    pub orientation_correction: OrientationCorrectionConfig,
85    /// Configuration for grouping strategy
86    #[serde(default)]
87    pub grouping_strategy: GroupingStrategyConfig,
88}
89
90impl RecognitionConfig {
91    /// Create a RecognitionConfig from the old config format
92    pub fn from_legacy_config(
93        use_textline_orientation: bool,
94        aspect_ratio_bucketing: Option<crate::processors::AspectRatioBucketingConfig>,
95    ) -> Self {
96        let orientation_correction = OrientationCorrectionConfig {
97            enabled: use_textline_orientation,
98        };
99
100        let grouping_strategy = if let Some(bucketing_config) = aspect_ratio_bucketing {
101            GroupingStrategyConfig::AspectRatioBucketing(bucketing_config)
102        } else {
103            GroupingStrategyConfig::ExactDimensions
104        };
105
106        Self {
107            orientation_correction,
108            grouping_strategy,
109        }
110    }
111}
112
113impl ConfigValidator for RecognitionConfig {
114    fn validate(&self) -> Result<(), crate::core::config::ConfigError> {
115        // RecognitionConfig validation - basic validation for now
116        // Could add more specific validation for grouping strategy if needed
117        Ok(())
118    }
119
120    fn get_defaults() -> Self {
121        Self::default()
122    }
123}
124
/// Stateless entry point for the text recognition stage.
///
/// All functionality is exposed through associated functions, which together:
/// - group text images by aspect ratio buckets or by exact dimensions,
/// - apply text line orientation corrections,
/// - run batch recognition, logging and counting failed batches,
/// - collect the recognition results back into the original input order.
pub struct RecognitionStageProcessor;
133
134impl RecognitionStageProcessor {
135    /// Process text recognition for cropped text images.
136    ///
137    /// # Arguments
138    ///
139    /// * `cropped_images` - Vector of optional cropped images (None for failed crops)
140    /// * `text_line_orientations` - Optional orientation angles for each text region
141    /// * `recognizer` - Optional text recognizer
142    /// * `config` - Configuration for recognition processing
143    ///
144    /// # Returns
145    ///
146    /// A StageResult containing the recognition result and processing metrics
147    pub fn process_single(
148        cropped_images: Vec<Option<RgbImage>>,
149        text_line_orientations: Option<&[Option<f32>]>,
150        recognizer: Option<&TextRecPredictor>,
151        config: Option<&RecognitionConfig>,
152    ) -> Result<StageResult<RecognitionResult>, OCRError> {
153        use std::time::Instant;
154        let start_time = Instant::now();
155        let default_config = RecognitionConfig::default();
156        let config = config.unwrap_or(&default_config);
157
158        let mut failed_recognitions = 0;
159        let (rec_texts, rec_scores) = if cropped_images.is_empty() {
160            (Vec::new(), Vec::new())
161        } else if let Some(recognizer) = recognizer {
162            Self::process_recognition_groups(
163                &cropped_images,
164                text_line_orientations,
165                recognizer,
166                config,
167                &mut failed_recognitions,
168            )?
169        } else {
170            debug!("No text recognizer available, skipping recognition");
171            let empty_texts = vec![Arc::from(""); cropped_images.len()];
172            let empty_scores = vec![0.0; cropped_images.len()];
173            (empty_texts, empty_scores)
174        };
175
176        let success_count = cropped_images.len() - failed_recognitions;
177        let grouping_strategy_name = match config.grouping_strategy {
178            GroupingStrategyConfig::AspectRatioBucketing(_) => "aspect_ratio_bucketing",
179            GroupingStrategyConfig::ExactDimensions => "exact_dimensions",
180        };
181        let metrics = metrics!(success_count, failed_recognitions, start_time;
182            stage = "recognition",
183            text_regions = cropped_images.len(),
184            grouping_strategy = grouping_strategy_name
185        );
186
187        let result = RecognitionResult {
188            rec_texts,
189            rec_scores,
190            failed_recognitions,
191        };
192
193        Ok(StageResult::new(result, metrics))
194    }
195
196    /// Process recognition by grouping images and running batch recognition.
197    fn process_recognition_groups(
198        cropped_images: &[Option<RgbImage>],
199        text_line_orientations: Option<&[Option<f32>]>,
200        recognizer: &TextRecPredictor,
201        config: &RecognitionConfig,
202        failed_recognitions: &mut usize,
203    ) -> Result<(Vec<Arc<str>>, Vec<f32>), OCRError> {
204        // Prepare images for grouping (filter out None values)
205        let images_for_grouping: Vec<(usize, RgbImage)> = cropped_images
206            .iter()
207            .enumerate()
208            .filter_map(|(i, cropped_img_opt)| cropped_img_opt.as_ref().map(|img| (i, img.clone())))
209            .collect();
210
211        // Create grouping strategy and group images
212        let grouping_strategy =
213            GroupingStrategyFactory::create_strategy(&config.grouping_strategy)?;
214        let groups = grouping_strategy.group_images(images_for_grouping)?;
215
216        // Create orientation corrector
217        let orientation_corrector =
218            OrientationCorrector::new(config.orientation_correction.clone());
219
220        // Process each group
221        let mut recognition_results: Vec<(usize, Arc<str>, f32)> = Vec::new();
222        for (group_name, group) in groups {
223            Self::process_recognition_group(
224                group_name,
225                group,
226                text_line_orientations,
227                recognizer,
228                &orientation_corrector,
229                &mut recognition_results,
230                failed_recognitions,
231            )?;
232        }
233
234        // Sort results by original index and extract texts and scores
235        recognition_results.sort_by_key(|(idx, _, _)| *idx);
236        let mut rec_texts = vec![Arc::from(""); cropped_images.len()];
237        let mut rec_scores = vec![0.0; cropped_images.len()];
238
239        for (original_idx, text, score) in recognition_results {
240            if original_idx < rec_texts.len() {
241                rec_texts[original_idx] = text;
242                rec_scores[original_idx] = score;
243            }
244        }
245
246        Ok((rec_texts, rec_scores))
247    }
248
249    /// Process a single recognition group.
250    fn process_recognition_group(
251        group_name: String,
252        group: Vec<(usize, RgbImage)>,
253        text_line_orientations: Option<&[Option<f32>]>,
254        recognizer: &TextRecPredictor,
255        orientation_corrector: &OrientationCorrector,
256        recognition_results: &mut Vec<(usize, Arc<str>, f32)>,
257        failed_recognitions: &mut usize,
258    ) -> Result<(), OCRError> {
259        let (indices, mut images): (Vec<usize>, Vec<RgbImage>) = group.into_iter().unzip();
260
261        // Apply text line orientation corrections using the corrector
262        let corrections_applied =
263            orientation_corrector.apply_corrections(&mut images, &indices, text_line_orientations);
264
265        if corrections_applied > 0 {
266            debug!(
267                "Applied {} orientation corrections in group '{}'",
268                corrections_applied, group_name
269            );
270        }
271
272        debug!(
273            "Processing recognition group '{}' with {} images",
274            group_name,
275            images.len()
276        );
277
278        match recognizer.predict(images, None) {
279            Ok(result) => {
280                for (original_idx, (text, score)) in indices
281                    .into_iter()
282                    .zip(result.rec_text.iter().zip(result.rec_score.iter()))
283                {
284                    recognition_results.push((original_idx, text.clone(), *score));
285                }
286            }
287            Err(e) => {
288                *failed_recognitions += indices.len();
289
290                // Create enhanced error message using the common helper
291                let enhanced_error = OCRError::format_batch_error_message(
292                    "text recognition",
293                    &group_name,
294                    &indices,
295                    &e,
296                );
297
298                warn!(
299                    "Text recognition failed for batch '{}' of {} images (indices: {:?}): {}",
300                    group_name,
301                    indices.len(),
302                    indices,
303                    enhanced_error
304                );
305
306                for original_idx in indices {
307                    recognition_results.push((original_idx, Arc::from(""), 0.0));
308                }
309            }
310        }
311
312        Ok(())
313    }
314}
315
316/// Extensible recognition stage that implements PipelineStage trait.
317#[derive(Debug)]
318pub struct ExtensibleRecognitionStage {
319    recognizer: Option<Arc<TextRecPredictor>>,
320}
321
322impl ExtensibleRecognitionStage {
323    /// Create a new extensible recognition stage.
324    pub fn new(recognizer: Option<Arc<TextRecPredictor>>) -> Self {
325        Self { recognizer }
326    }
327}
328
329impl PipelineStage for ExtensibleRecognitionStage {
330    type Config = RecognitionConfig;
331    type Result = RecognitionResult;
332
333    fn stage_id(&self) -> StageId {
334        StageId::new("recognition")
335    }
336
337    fn stage_name(&self) -> &str {
338        "Text Recognition"
339    }
340
341    fn dependencies(&self) -> Vec<StageDependency> {
342        // Recognition depends on cropping results
343        vec![StageDependency::Requires(StageId::new("cropping"))]
344    }
345
346    fn is_enabled(&self, _context: &StageContext, _config: Option<&Self::Config>) -> bool {
347        self.recognizer.is_some()
348    }
349
350    fn process(
351        &self,
352        context: &mut StageContext,
353        _data: StageData,
354        config: Option<&Self::Config>,
355    ) -> Result<StageResult<Self::Result>, OCRError> {
356        // Get cropped images from the cropping stage
357        let cropping_result = context
358            .get_stage_result::<super::cropping::CroppingResult>(&StageId::new("cropping"))
359            .ok_or_else(|| {
360                OCRError::processing_error(
361                    crate::core::ProcessingStage::Generic,
362                    "Cropping results not found in context",
363                    crate::core::errors::SimpleError::new("Missing cropping results"),
364                )
365            })?;
366
367        // Get text line orientations if available
368        let text_line_orientations =
369            context.get_stage_result::<Vec<Option<f32>>>(&StageId::new("text_line_orientation"));
370
371        let recognition_config = config.cloned().unwrap_or_default();
372
373        let stage_result = RecognitionStageProcessor::process_single(
374            cropping_result.cropped_images.clone(),
375            text_line_orientations.map(|v| &**v),
376            self.recognizer.as_ref().map(|r| r.as_ref()),
377            Some(&recognition_config),
378        )?;
379
380        Ok(stage_result)
381    }
382
383    fn validate_config(&self, config: &Self::Config) -> Result<(), OCRError> {
384        config.validate().map_err(|e| OCRError::ConfigError {
385            message: format!("RecognitionConfig validation failed: {}", e),
386        })
387    }
388
389    fn default_config(&self) -> Self::Config {
390        RecognitionConfig::get_defaults()
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::AspectRatioBucketingConfig;

    /// The legacy constructor maps its two arguments onto the new config fields.
    #[test]
    fn test_recognition_config_from_legacy_config() {
        // Orientation enabled, aspect ratio bucketing supplied.
        let with_bucketing = RecognitionConfig::from_legacy_config(
            true,
            Some(AspectRatioBucketingConfig::default()),
        );
        assert!(with_bucketing.orientation_correction.enabled);
        assert!(
            matches!(
                with_bucketing.grouping_strategy,
                GroupingStrategyConfig::AspectRatioBucketing(_)
            ),
            "Should be AspectRatioBucketing"
        );

        // Orientation disabled, no bucketing supplied.
        let without_bucketing = RecognitionConfig::from_legacy_config(false, None);
        assert!(!without_bucketing.orientation_correction.enabled);
        assert!(
            matches!(
                without_bucketing.grouping_strategy,
                GroupingStrategyConfig::ExactDimensions
            ),
            "Should be ExactDimensions"
        );
    }

    /// A config survives a JSON round trip with its orientation flag and
    /// grouping strategy variant intact.
    #[test]
    fn test_recognition_config_serialization() {
        let config = RecognitionConfig::from_legacy_config(
            true,
            Some(AspectRatioBucketingConfig::default()),
        );

        let json = serde_json::to_string(&config).unwrap();
        let round_tripped: RecognitionConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(
            config.orientation_correction.enabled,
            round_tripped.orientation_correction.enabled
        );

        let both_bucketing = matches!(
            (&config.grouping_strategy, &round_tripped.grouping_strategy),
            (
                GroupingStrategyConfig::AspectRatioBucketing(_),
                GroupingStrategyConfig::AspectRatioBucketing(_),
            )
        );
        assert!(
            both_bucketing,
            "Grouping strategy should match after serialization"
        );
    }
}