Skip to main content

ocr_rs/
rec.rs

1//! Text Recognition Model
2//!
3//! Provides text recognition functionality based on PaddleOCR recognition models
4
5use image::DynamicImage;
6use ndarray::ArrayD;
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::{preprocess_for_rec, NormalizeParams};
12
/// Outcome of recognizing a single text-line image.
#[derive(Debug, Clone)]
pub struct RecognitionResult {
    /// Decoded text, in the order the characters were emitted
    pub text: String,
    /// Overall confidence of the decoded text (expected range 0.0 - 1.0)
    pub confidence: f32,
    /// Every emitted character paired with its individual score
    pub char_scores: Vec<(char, f32)>,
}

impl RecognitionResult {
    /// Assemble a result from its text, overall confidence and per-character scores.
    pub fn new(text: String, confidence: f32, char_scores: Vec<(char, f32)>) -> Self {
        RecognitionResult {
            text,
            confidence,
            char_scores,
        }
    }

    /// Returns `true` when the overall confidence reaches `threshold`.
    ///
    /// The comparison is inclusive: a confidence equal to `threshold` is valid.
    pub fn is_valid(&self, threshold: f32) -> bool {
        threshold <= self.confidence
    }
}
39
/// Tunable settings for text recognition.
#[derive(Debug, Clone)]
pub struct RecOptions {
    /// Target height (recognition model input height)
    pub target_height: u32,
    /// Minimum score a non-punctuation character needs in order to be kept
    pub min_score: f32,
    /// Minimum score a punctuation character needs in order to be kept
    pub punct_min_score: f32,
    /// Maximum number of images grouped into one batch
    pub batch_size: usize,
    /// Whether batch recognition may group images into batches
    pub enable_batch: bool,
}

impl Default for RecOptions {
    /// Defaults: height 48, `min_score` 0.3, `punct_min_score` 0.1,
    /// batches of 8 with batching enabled.
    fn default() -> Self {
        Self {
            target_height: 48,
            // Deliberately permissive: the per-character score is the winning
            // value of the model output at that time step, used as-is.
            min_score: 0.3,
            punct_min_score: 0.1,
            batch_size: 8,
            enable_batch: true,
        }
    }
}

impl RecOptions {
    /// Create recognition options populated with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the target height.
    pub fn with_target_height(mut self, height: u32) -> Self {
        self.target_height = height;
        self
    }

    /// Set the minimum score for non-punctuation characters.
    pub fn with_min_score(mut self, score: f32) -> Self {
        self.min_score = score;
        self
    }

    /// Set the minimum score for punctuation characters.
    pub fn with_punct_min_score(mut self, score: f32) -> Self {
        self.punct_min_score = score;
        self
    }

    /// Set the batch size.
    ///
    /// A batch must hold at least one image, so a `size` of 0 is stored as 1
    /// instead of producing an unusable configuration.
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size.max(1);
        self
    }

    /// Enable or disable batch processing.
    pub fn with_batch(mut self, enable: bool) -> Self {
        self.enable_batch = enable;
        self
    }
}
103
104/// Text recognition model
105pub struct RecModel {
106    engine: InferenceEngine,
107    /// Character set (index to character mapping)
108    charset: Vec<char>,
109    options: RecOptions,
110    normalize_params: NormalizeParams,
111}
112
/// Common punctuation marks: ASCII punctuation followed by CJK punctuation.
///
/// The fullwidth forms are written as `\u{..}` escapes so they cannot be
/// silently folded into their ASCII look-alikes, which would leave the list
/// with duplicate ASCII entries and no fullwidth coverage.
const PUNCTUATIONS: [char; 49] = [
    // ASCII punctuation
    ',', '.', '!', '?', ';', ':', '"', '\'', '(', ')', '[', ']', '{', '}', '-', '_', '/', '\\',
    '|', '@', '#', '$', '%', '&', '*', '+', '=', '~',
    // Fullwidth comma, ideographic full stop, fullwidth ! ? ; :, ideographic comma
    '\u{FF0C}', '。', '\u{FF01}', '\u{FF1F}', '\u{FF1B}', '\u{FF1A}', '、',
    // CJK brackets; the fifth and sixth entries are the fullwidth parentheses
    '「', '」', '『', '』', '\u{FF08}', '\u{FF09}', '【', '】', '《', '》',
    // Dash, ellipsis, middle dot, fullwidth tilde
    '—', '…', '·', '\u{FF5E}',
];
119
120impl RecModel {
121    /// Create recognizer from model file and charset file
122    ///
123    /// # Parameters
124    /// - `model_path`: Model file path (.mnn format)
125    /// - `charset_path`: Charset file path (one character per line)
126    /// - `config`: Optional inference config
127    pub fn from_file(
128        model_path: impl AsRef<Path>,
129        charset_path: impl AsRef<Path>,
130        config: Option<InferenceConfig>,
131    ) -> OcrResult<Self> {
132        let engine = InferenceEngine::from_file(model_path, config)?;
133        let charset = Self::load_charset_from_file(charset_path)?;
134
135        Ok(Self {
136            engine,
137            charset,
138            options: RecOptions::default(),
139            normalize_params: NormalizeParams::paddle_rec(),
140        })
141    }
142
143    /// Create recognizer from model bytes and charset file
144    pub fn from_bytes(
145        model_bytes: &[u8],
146        charset_path: impl AsRef<Path>,
147        config: Option<InferenceConfig>,
148    ) -> OcrResult<Self> {
149        let engine = InferenceEngine::from_buffer(model_bytes, config)?;
150        let charset = Self::load_charset_from_file(charset_path)?;
151
152        Ok(Self {
153            engine,
154            charset,
155            options: RecOptions::default(),
156            normalize_params: NormalizeParams::paddle_rec(),
157        })
158    }
159
160    /// Create recognizer from model bytes and charset bytes
161    pub fn from_bytes_with_charset(
162        model_bytes: &[u8],
163        charset_bytes: &[u8],
164        config: Option<InferenceConfig>,
165    ) -> OcrResult<Self> {
166        let engine = InferenceEngine::from_buffer(model_bytes, config)?;
167        let charset = Self::parse_charset(charset_bytes)?;
168
169        Ok(Self {
170            engine,
171            charset,
172            options: RecOptions::default(),
173            normalize_params: NormalizeParams::paddle_rec(),
174        })
175    }
176
177    /// Load charset from file
178    fn load_charset_from_file(path: impl AsRef<Path>) -> OcrResult<Vec<char>> {
179        let content = std::fs::read_to_string(path)?;
180        Self::parse_charset(content.as_bytes())
181    }
182
183    /// Parse charset data
184    fn parse_charset(data: &[u8]) -> OcrResult<Vec<char>> {
185        let content = std::str::from_utf8(data)
186            .map_err(|e| OcrError::CharsetError(format!("UTF-8 decode error: {}", e)))?;
187
188        // Charset format: one character per line
189        // Add space at beginning and end as blank and padding
190        let mut charset: Vec<char> = vec![' ']; // blank token at start
191
192        for ch in content.chars() {
193            if ch != '\n' && ch != '\r' {
194                charset.push(ch);
195            }
196        }
197
198        charset.push(' '); // padding token at end
199
200        if charset.len() < 3 {
201            return Err(OcrError::CharsetError("Charset too small".to_string()));
202        }
203
204        Ok(charset)
205    }
206
207    /// Set recognition options
208    pub fn with_options(mut self, options: RecOptions) -> Self {
209        self.options = options;
210        self
211    }
212
213    /// Get current recognition options
214    pub fn options(&self) -> &RecOptions {
215        &self.options
216    }
217
218    /// Modify recognition options
219    pub fn options_mut(&mut self) -> &mut RecOptions {
220        &mut self.options
221    }
222
223    /// Get charset size
224    pub fn charset_size(&self) -> usize {
225        self.charset.len()
226    }
227
228    /// Recognize a single image
229    ///
230    /// # Parameters
231    /// - `image`: Input image (text line image)
232    ///
233    /// # Returns
234    /// Recognition result
235    pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
236        // Preprocess
237        let input = preprocess_for_rec(image, self.options.target_height, &self.normalize_params)?;
238
239        // Inference (using dynamic shape)
240        let output = self.engine.run_dynamic(input.view().into_dyn())?;
241
242        // Decode
243        self.decode_output(&output)
244    }
245
246    /// Recognize a single image, return text only
247    pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
248        let result = self.recognize(image)?;
249        Ok(result.text)
250    }
251
252    /// Batch recognize images
253    ///
254    /// # Parameters
255    /// - `images`: List of input images
256    ///
257    /// # Returns
258    /// List of recognition results
259    pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
260        if images.is_empty() {
261            return Ok(Vec::new());
262        }
263
264        // For small number of images, process individually
265        if images.len() <= 2 || !self.options.enable_batch {
266            return images.iter().map(|img| self.recognize(img)).collect();
267        }
268
269        // Batch processing
270        let mut results = Vec::with_capacity(images.len());
271
272        for chunk in images.chunks(self.options.batch_size) {
273            let batch_results = self.recognize_batch_internal(chunk)?;
274            results.extend(batch_results);
275        }
276
277        Ok(results)
278    }
279
280    /// Batch recognize images (borrowed version, avoid cloning)
281    ///
282    /// # Parameters
283    /// - `images`: List of input image references
284    ///
285    /// # Returns
286    /// List of recognition results
287    pub fn recognize_batch_ref(
288        &self,
289        images: &[&DynamicImage],
290    ) -> OcrResult<Vec<RecognitionResult>> {
291        if images.is_empty() {
292            return Ok(Vec::new());
293        }
294
295        // For small number of images, process individually
296        if images.len() <= 2 || !self.options.enable_batch {
297            return images.iter().map(|img| self.recognize(img)).collect();
298        }
299
300        // Batch processing
301        let mut results = Vec::with_capacity(images.len());
302
303        for chunk in images.chunks(self.options.batch_size) {
304            // Dereference and convert to Vec<DynamicImage>
305            let chunk_owned: Vec<DynamicImage> = chunk.iter().map(|img| (*img).clone()).collect();
306            let batch_results = self.recognize_batch_internal(&chunk_owned)?;
307            results.extend(batch_results);
308        }
309
310        Ok(results)
311    }
312
313    /// Internal batch recognition
314    fn recognize_batch_internal(
315        &self,
316        images: &[DynamicImage],
317    ) -> OcrResult<Vec<RecognitionResult>> {
318        if images.is_empty() {
319            return Ok(Vec::new());
320        }
321
322        // If only one image, process individually
323        if images.len() == 1 {
324            return Ok(vec![self.recognize(&images[0])?]);
325        }
326
327        // Batch preprocessing
328        let batch_input = crate::preprocess::preprocess_batch_for_rec(
329            images,
330            self.options.target_height,
331            &self.normalize_params,
332        )?;
333
334        // Batch inference
335        let batch_output = self.engine.run_dynamic(batch_input.view().into_dyn())?;
336
337        // Decode output for each sample
338        let shape = batch_output.shape();
339        if shape.len() != 3 {
340            return Err(OcrError::PostprocessError(format!(
341                "Batch inference output shape error: {:?}",
342                shape
343            )));
344        }
345
346        let batch_size = shape[0];
347        let mut results = Vec::with_capacity(batch_size);
348
349        for i in 0..batch_size {
350            // Extract output for single sample
351            let sample_output = batch_output.slice(ndarray::s![i, .., ..]).to_owned();
352            let sample_output_dyn = sample_output.into_dyn();
353            let result = self.decode_output(&sample_output_dyn)?;
354            results.push(result);
355        }
356
357        Ok(results)
358    }
359
360    /// Decode model output
361    fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<RecognitionResult> {
362        let shape = output.shape();
363
364        // Output shape should be [batch, seq_len, num_classes] or [seq_len, num_classes]
365        let (seq_len, num_classes) = if shape.len() == 3 {
366            (shape[1], shape[2])
367        } else if shape.len() == 2 {
368            (shape[0], shape[1])
369        } else {
370            return Err(OcrError::PostprocessError(format!(
371                "Invalid output shape: {:?}",
372                shape
373            )));
374        };
375
376        let output_data: Vec<f32> = output.iter().cloned().collect();
377
378        // CTC decoding
379        let mut char_scores = Vec::new();
380        let mut prev_idx = 0usize;
381
382        for t in 0..seq_len {
383            // Find character with maximum probability at current time step
384            let start = t * num_classes;
385            let end = start + num_classes;
386            let probs = &output_data[start..end];
387
388            let (max_idx, &max_prob) = probs
389                .iter()
390                .enumerate()
391                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
392                .ok_or_else(|| {
393                    OcrError::PostprocessError("Empty probability slice in CTC decoding".into())
394                })?;
395
396            // CTC decoding rule: skip blank (index 0) and duplicate characters
397            if max_idx != 0 && max_idx != prev_idx {
398                if max_idx < self.charset.len() {
399                    let ch = self.charset[max_idx];
400
401                    // Use raw logit value as confidence (model output is already softmax probability)
402                    // For large character sets, softmax scores can be very small, so use max_prob directly
403                    let score = max_prob;
404
405                    // Only filter out very low confidence characters
406                    let threshold = if Self::is_punctuation(ch) {
407                        self.options.punct_min_score
408                    } else {
409                        self.options.min_score
410                    };
411
412                    if score >= threshold {
413                        char_scores.push((ch, score));
414                    }
415                }
416            }
417
418            prev_idx = max_idx;
419        }
420
421        // Calculate average confidence
422        let confidence = if char_scores.is_empty() {
423            0.0
424        } else {
425            char_scores.iter().map(|(_, s)| s).sum::<f32>() / char_scores.len() as f32
426        };
427
428        // Extract text
429        let text: String = char_scores.iter().map(|(ch, _)| ch).collect();
430
431        Ok(RecognitionResult::new(text, confidence, char_scores))
432    }
433
434    /// Check if character is punctuation
435    fn is_punctuation(ch: char) -> bool {
436        PUNCTUATIONS.contains(&ch)
437    }
438}
439
440/// Low-level recognition API
441impl RecModel {
442    /// Raw inference interface
443    ///
444    /// Execute model inference directly without preprocessing and postprocessing
445    ///
446    /// # Parameters
447    /// - `input`: Preprocessed input tensor [1, 3, H, W]
448    ///
449    /// # Returns
450    /// Model raw output
451    pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
452        Ok(self.engine.run_dynamic(input)?)
453    }
454
455    /// Get model input shape
456    pub fn input_shape(&self) -> &[usize] {
457        self.engine.input_shape()
458    }
459
460    /// Get model output shape
461    pub fn output_shape(&self) -> &[usize] {
462        self.engine.output_shape()
463    }
464
465    /// Get charset
466    pub fn charset(&self) -> &[char] {
467        &self.charset
468    }
469
470    /// Get character by index
471    pub fn get_char(&self, index: usize) -> Option<char> {
472        self.charset.get(index).copied()
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-character scores shared by the `RecognitionResult` tests.
    fn hello_scores() -> Vec<(char, f32)> {
        vec![
            ('H', 0.99),
            ('e', 0.94),
            ('l', 0.93),
            ('l', 0.95),
            ('o', 0.94),
        ]
    }

    /// Assert that every character of `chars` is classified as punctuation.
    fn assert_all_punctuation(chars: &str) {
        for ch in chars.chars() {
            assert!(
                RecModel::is_punctuation(ch),
                "{:?} should be punctuation",
                ch
            );
        }
    }

    /// Assert that no character of `chars` is classified as punctuation.
    fn assert_none_punctuation(chars: &str) {
        for ch in chars.chars() {
            assert!(
                !RecModel::is_punctuation(ch),
                "{:?} should not be punctuation",
                ch
            );
        }
    }

    #[test]
    fn test_rec_options_default() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::default();

        assert_eq!(target_height, 48);
        assert_eq!(min_score, 0.3);
        assert_eq!(punct_min_score, 0.1);
        assert_eq!(batch_size, 8);
        assert!(enable_batch);
    }

    #[test]
    fn test_rec_options_builder() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::new()
            .with_target_height(32)
            .with_min_score(0.6)
            .with_punct_min_score(0.2)
            .with_batch_size(16)
            .with_batch(false);

        assert_eq!(target_height, 32);
        assert_eq!(min_score, 0.6);
        assert_eq!(punct_min_score, 0.2);
        assert_eq!(batch_size, 16);
        assert!(!enable_batch);
    }

    #[test]
    fn test_recognition_result_new() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());

        assert_eq!(result.text, "Hello");
        assert_eq!(result.confidence, 0.95);
        assert_eq!(result.char_scores.len(), 5);

        let (first_char, first_score) = result.char_scores[0];
        assert_eq!(first_char, 'H');
        assert_eq!(first_score, 0.99);
    }

    #[test]
    fn test_recognition_result_is_valid() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());

        // Thresholds at or below the confidence are accepted.
        for &threshold in &[0.9, 0.95] {
            assert!(result.is_valid(threshold));
        }
        // Thresholds above the confidence are rejected.
        for &threshold in &[0.96, 0.99] {
            assert!(!result.is_valid(threshold));
        }
    }

    #[test]
    fn test_recognition_result_empty() {
        let result = RecognitionResult::new(String::new(), 0.0, vec![]);

        assert!(result.text.is_empty());
        assert_eq!(result.confidence, 0.0);
        assert!(!result.is_valid(0.1));
    }

    #[test]
    fn test_is_punctuation_common() {
        // English punctuation
        assert_all_punctuation(",.!?;:\"'");
    }

    #[test]
    fn test_is_punctuation_chinese() {
        // Chinese punctuation
        assert_all_punctuation(",。!?;:、—…");
    }

    #[test]
    fn test_is_punctuation_brackets() {
        assert_all_punctuation("()[]{}「」《》");
    }

    #[test]
    fn test_is_punctuation_false() {
        // Non-punctuation characters
        assert_none_punctuation("Az0中文 ");
    }
}