Skip to main content

ocr_rs/
rec.rs

1//! Text Recognition Model
2//!
3//! Provides text recognition functionality based on PaddleOCR recognition models
4
5use image::DynamicImage;
6use ndarray::ArrayD;
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::{preprocess_for_rec, NormalizeParams};
12
/// Outcome of recognizing a single text-line image.
///
/// `Default` yields an empty result: no text, a confidence of `0.0` and no
/// per-character scores.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct RecognitionResult {
    /// Recognized text
    pub text: String,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,
    /// Confidence score for each character
    pub char_scores: Vec<(char, f32)>,
}

impl RecognitionResult {
    /// Create a new recognition result.
    ///
    /// The three values are stored exactly as given; no consistency check is
    /// performed between `text` and `char_scores`.
    pub fn new(text: String, confidence: f32, char_scores: Vec<(char, f32)>) -> Self {
        Self {
            text,
            confidence,
            char_scores,
        }
    }

    /// Check if the result is valid, i.e. its confidence is at least `threshold`.
    ///
    /// A NaN confidence or threshold always yields `false`.
    pub fn is_valid(&self, threshold: f32) -> bool {
        self.confidence >= threshold
    }
}
39
/// Recognition options
///
/// Built either through `Default`/[`RecOptions::new`] plus the `with_*`
/// builder methods, or by setting the public fields directly.
#[derive(Debug, Clone, PartialEq)]
pub struct RecOptions {
    /// Target height (recognition model input height)
    pub target_height: u32,
    /// Minimum confidence threshold (characters below this value will be filtered)
    pub min_score: f32,
    /// Minimum confidence threshold for punctuation
    pub punct_min_score: f32,
    /// Batch size (number of images per batched inference call)
    pub batch_size: usize,
    /// Whether to enable batch processing
    pub enable_batch: bool,
}

impl Default for RecOptions {
    fn default() -> Self {
        Self {
            target_height: 48,
            // Deliberately low default threshold.
            // NOTE(review): the previous comment here called the model output a
            // "raw logit"; confirm the score scale against the exported model
            // before tuning this value.
            min_score: 0.3,
            punct_min_score: 0.1,
            batch_size: 8,
            enable_batch: true,
        }
    }
}

impl RecOptions {
    /// Create new recognition options (identical to `RecOptions::default()`)
    pub fn new() -> Self {
        Self::default()
    }

    /// Set target height
    pub fn with_target_height(mut self, height: u32) -> Self {
        self.target_height = height;
        self
    }

    /// Set minimum confidence
    pub fn with_min_score(mut self, score: f32) -> Self {
        self.min_score = score;
        self
    }

    /// Set punctuation minimum confidence
    pub fn with_punct_min_score(mut self, score: f32) -> Self {
        self.punct_min_score = score;
        self
    }

    /// Set batch size
    ///
    /// A size of `0` is raised to `1`: a zero batch size cannot be used to
    /// split a list of images into chunks (`slice::chunks(0)` panics).
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size.max(1);
        self
    }

    /// Enable/disable batch processing
    pub fn with_batch(mut self, enable: bool) -> Self {
        self.enable_batch = enable;
        self
    }
}
103
104/// Text recognition model
105pub struct RecModel {
106    engine: InferenceEngine,
107    /// Character set (index to character mapping)
108    charset: Vec<char>,
109    options: RecOptions,
110    normalize_params: NormalizeParams,
111}
112
/// Common punctuation marks: ASCII, full-width and CJK.
///
/// The full-width forms are written as `\u{...}` escapes so they cannot be
/// mistaken for (or silently normalized into) their ASCII look-alikes, which
/// would leave duplicate ASCII entries and drop the full-width marks.
const PUNCTUATIONS: [char; 49] = [
    // ASCII punctuation
    ',', '.', '!', '?', ';', ':', '"', '\'', '(', ')', '[', ']', '{', '}', '-', '_', '/', '\\',
    '|', '@', '#', '$', '%', '&', '*', '+', '=', '~',
    // Full-width comma, CJK full stop, full-width ! ? ; :, CJK enumeration comma
    '\u{FF0C}', '。', '\u{FF01}', '\u{FF1F}', '\u{FF1B}', '\u{FF1A}', '、',
    // CJK brackets, full-width parentheses (U+FF08 / U+FF09)
    '「', '」', '『', '』', '\u{FF08}', '\u{FF09}', '【', '】', '《', '》',
    // Dashes, ellipsis, middle dot, full-width tilde (U+FF5E)
    '—', '…', '·', '\u{FF5E}',
];
119
120impl RecModel {
121    /// Create recognizer from model file and charset file
122    ///
123    /// # Parameters
124    /// - `model_path`: Model file path (.mnn format)
125    /// - `charset_path`: Charset file path (one character per line)
126    /// - `config`: Optional inference config
127    pub fn from_file(
128        model_path: impl AsRef<Path>,
129        charset_path: impl AsRef<Path>,
130        config: Option<InferenceConfig>,
131    ) -> OcrResult<Self> {
132        let engine = InferenceEngine::from_file(model_path, config)?;
133        let charset = Self::load_charset_from_file(charset_path)?;
134
135        Ok(Self {
136            engine,
137            charset,
138            options: RecOptions::default(),
139            normalize_params: NormalizeParams::paddle_rec(),
140        })
141    }
142
143    /// Create recognizer from model bytes and charset file
144    pub fn from_bytes(
145        model_bytes: &[u8],
146        charset_path: impl AsRef<Path>,
147        config: Option<InferenceConfig>,
148    ) -> OcrResult<Self> {
149        let engine = InferenceEngine::from_buffer(model_bytes, config)?;
150        let charset = Self::load_charset_from_file(charset_path)?;
151
152        Ok(Self {
153            engine,
154            charset,
155            options: RecOptions::default(),
156            normalize_params: NormalizeParams::paddle_rec(),
157        })
158    }
159
160    /// Create recognizer from model bytes and charset bytes
161    pub fn from_bytes_with_charset(
162        model_bytes: &[u8],
163        charset_bytes: &[u8],
164        config: Option<InferenceConfig>,
165    ) -> OcrResult<Self> {
166        let engine = InferenceEngine::from_buffer(model_bytes, config)?;
167        let charset = Self::parse_charset(charset_bytes)?;
168
169        Ok(Self {
170            engine,
171            charset,
172            options: RecOptions::default(),
173            normalize_params: NormalizeParams::paddle_rec(),
174        })
175    }
176
177    /// Load charset from file
178    fn load_charset_from_file(path: impl AsRef<Path>) -> OcrResult<Vec<char>> {
179        let content = std::fs::read_to_string(path)?;
180        Self::parse_charset(content.as_bytes())
181    }
182
183    /// Parse charset data
184    fn parse_charset(data: &[u8]) -> OcrResult<Vec<char>> {
185        let content = std::str::from_utf8(data)
186            .map_err(|e| OcrError::CharsetError(format!("UTF-8 decode error: {}", e)))?;
187
188        // Charset format: one character per line
189        // Add space at beginning and end as blank and padding
190        let mut charset: Vec<char> = vec![' ']; // blank token at start
191
192        for ch in content.chars() {
193            if ch != '\n' && ch != '\r' {
194                charset.push(ch);
195            }
196        }
197
198        charset.push(' '); // padding token at end
199
200        if charset.len() < 3 {
201            return Err(OcrError::CharsetError("Charset too small".to_string()));
202        }
203
204        Ok(charset)
205    }
206
207    /// Set recognition options
208    pub fn with_options(mut self, options: RecOptions) -> Self {
209        self.options = options;
210        self
211    }
212
213    /// Get current recognition options
214    pub fn options(&self) -> &RecOptions {
215        &self.options
216    }
217
218    /// Modify recognition options
219    pub fn options_mut(&mut self) -> &mut RecOptions {
220        &mut self.options
221    }
222
223    /// Get charset size
224    pub fn charset_size(&self) -> usize {
225        self.charset.len()
226    }
227
228    /// Recognize a single image
229    ///
230    /// # Parameters
231    /// - `image`: Input image (text line image)
232    ///
233    /// # Returns
234    /// Recognition result
235    pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
236        // Preprocess
237        let input = preprocess_for_rec(image, self.options.target_height, &self.normalize_params);
238
239        // Inference (using dynamic shape)
240        let output = self.engine.run_dynamic(input.view().into_dyn())?;
241
242        // Decode
243        self.decode_output(&output)
244    }
245
246    /// Recognize a single image, return text only
247    pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
248        let result = self.recognize(image)?;
249        Ok(result.text)
250    }
251
252    /// Batch recognize images
253    ///
254    /// # Parameters
255    /// - `images`: List of input images
256    ///
257    /// # Returns
258    /// List of recognition results
259    pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
260        if images.is_empty() {
261            return Ok(Vec::new());
262        }
263
264        // For small number of images, process individually
265        if images.len() <= 2 || !self.options.enable_batch {
266            return images.iter().map(|img| self.recognize(img)).collect();
267        }
268
269        // Batch processing
270        let mut results = Vec::with_capacity(images.len());
271
272        for chunk in images.chunks(self.options.batch_size) {
273            let batch_results = self.recognize_batch_internal(chunk)?;
274            results.extend(batch_results);
275        }
276
277        Ok(results)
278    }
279
280    /// Batch recognize images (borrowed version, avoid cloning)
281    ///
282    /// # Parameters
283    /// - `images`: List of input image references
284    ///
285    /// # Returns
286    /// List of recognition results
287    pub fn recognize_batch_ref(
288        &self,
289        images: &[&DynamicImage],
290    ) -> OcrResult<Vec<RecognitionResult>> {
291        if images.is_empty() {
292            return Ok(Vec::new());
293        }
294
295        // For small number of images, process individually
296        if images.len() <= 2 || !self.options.enable_batch {
297            return images.iter().map(|img| self.recognize(img)).collect();
298        }
299
300        // Batch processing
301        let mut results = Vec::with_capacity(images.len());
302
303        for chunk in images.chunks(self.options.batch_size) {
304            // Dereference and convert to Vec<DynamicImage>
305            let chunk_owned: Vec<DynamicImage> = chunk.iter().map(|img| (*img).clone()).collect();
306            let batch_results = self.recognize_batch_internal(&chunk_owned)?;
307            results.extend(batch_results);
308        }
309
310        Ok(results)
311    }
312
313    /// Internal batch recognition
314    fn recognize_batch_internal(
315        &self,
316        images: &[DynamicImage],
317    ) -> OcrResult<Vec<RecognitionResult>> {
318        if images.is_empty() {
319            return Ok(Vec::new());
320        }
321
322        // If only one image, process individually
323        if images.len() == 1 {
324            return Ok(vec![self.recognize(&images[0])?]);
325        }
326
327        // Batch preprocessing
328        let batch_input = crate::preprocess::preprocess_batch_for_rec(
329            images,
330            self.options.target_height,
331            &self.normalize_params,
332        );
333
334        // Batch inference
335        let batch_output = self.engine.run_dynamic(batch_input.view().into_dyn())?;
336
337        // Decode output for each sample
338        let shape = batch_output.shape();
339        if shape.len() != 3 {
340            return Err(OcrError::PostprocessError(format!(
341                "Batch inference output shape error: {:?}",
342                shape
343            )));
344        }
345
346        let batch_size = shape[0];
347        let mut results = Vec::with_capacity(batch_size);
348
349        for i in 0..batch_size {
350            // Extract output for single sample
351            let sample_output = batch_output.slice(ndarray::s![i, .., ..]).to_owned();
352            let sample_output_dyn = sample_output.into_dyn();
353            let result = self.decode_output(&sample_output_dyn)?;
354            results.push(result);
355        }
356
357        Ok(results)
358    }
359
360    /// Decode model output
361    fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<RecognitionResult> {
362        let shape = output.shape();
363
364        // Output shape should be [batch, seq_len, num_classes] or [seq_len, num_classes]
365        let (seq_len, num_classes) = if shape.len() == 3 {
366            (shape[1], shape[2])
367        } else if shape.len() == 2 {
368            (shape[0], shape[1])
369        } else {
370            return Err(OcrError::PostprocessError(format!(
371                "Invalid output shape: {:?}",
372                shape
373            )));
374        };
375
376        let output_data: Vec<f32> = output.iter().cloned().collect();
377
378        // CTC decoding
379        let mut char_scores = Vec::new();
380        let mut prev_idx = 0usize;
381
382        for t in 0..seq_len {
383            // Find character with maximum probability at current time step
384            let start = t * num_classes;
385            let end = start + num_classes;
386            let probs = &output_data[start..end];
387
388            let (max_idx, &max_prob) = probs
389                .iter()
390                .enumerate()
391                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
392                .unwrap();
393
394            // CTC decoding rule: skip blank (index 0) and duplicate characters
395            if max_idx != 0 && max_idx != prev_idx {
396                if max_idx < self.charset.len() {
397                    let ch = self.charset[max_idx];
398
399                    // Use raw logit value as confidence (model output is already softmax probability)
400                    // For large character sets, softmax scores can be very small, so use max_prob directly
401                    let score = max_prob;
402
403                    // Only filter out very low confidence characters
404                    let threshold = if Self::is_punctuation(ch) {
405                        self.options.punct_min_score
406                    } else {
407                        self.options.min_score
408                    };
409
410                    if score >= threshold {
411                        char_scores.push((ch, score));
412                    }
413                }
414            }
415
416            prev_idx = max_idx;
417        }
418
419        // Calculate average confidence
420        let confidence = if char_scores.is_empty() {
421            0.0
422        } else {
423            char_scores.iter().map(|(_, s)| s).sum::<f32>() / char_scores.len() as f32
424        };
425
426        // Extract text
427        let text: String = char_scores.iter().map(|(ch, _)| ch).collect();
428
429        Ok(RecognitionResult::new(text, confidence, char_scores))
430    }
431
432    /// Check if character is punctuation
433    fn is_punctuation(ch: char) -> bool {
434        PUNCTUATIONS.contains(&ch)
435    }
436}
437
438/// Low-level recognition API
439impl RecModel {
440    /// Raw inference interface
441    ///
442    /// Execute model inference directly without preprocessing and postprocessing
443    ///
444    /// # Parameters
445    /// - `input`: Preprocessed input tensor [1, 3, H, W]
446    ///
447    /// # Returns
448    /// Model raw output
449    pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
450        Ok(self.engine.run_dynamic(input)?)
451    }
452
453    /// Get model input shape
454    pub fn input_shape(&self) -> &[usize] {
455        self.engine.input_shape()
456    }
457
458    /// Get model output shape
459    pub fn output_shape(&self) -> &[usize] {
460        self.engine.output_shape()
461    }
462
463    /// Get charset
464    pub fn charset(&self) -> &[char] {
465        &self.charset
466    }
467
468    /// Get character by index
469    pub fn get_char(&self, index: usize) -> Option<char> {
470        self.charset.get(index).copied()
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-character scores for the word "Hello", shared by several tests.
    fn hello_scores() -> Vec<(char, f32)> {
        vec![
            ('H', 0.99),
            ('e', 0.94),
            ('l', 0.93),
            ('l', 0.95),
            ('o', 0.94),
        ]
    }

    /// Assert that every character in `chars` is classified as punctuation.
    fn expect_punctuation(chars: &[char]) {
        for &ch in chars {
            assert!(RecModel::is_punctuation(ch));
        }
    }

    #[test]
    fn test_rec_options_default() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::default();

        assert_eq!(target_height, 48);
        assert_eq!(min_score, 0.3);
        assert_eq!(punct_min_score, 0.1);
        assert_eq!(batch_size, 8);
        assert!(enable_batch);
    }

    #[test]
    fn test_rec_options_builder() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::new()
            .with_target_height(32)
            .with_min_score(0.6)
            .with_punct_min_score(0.2)
            .with_batch_size(16)
            .with_batch(false);

        assert_eq!(target_height, 32);
        assert_eq!(min_score, 0.6);
        assert_eq!(punct_min_score, 0.2);
        assert_eq!(batch_size, 16);
        assert!(!enable_batch);
    }

    #[test]
    fn test_recognition_result_new() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());

        assert_eq!(result.text, "Hello");
        assert_eq!(result.confidence, 0.95);
        assert_eq!(result.char_scores.len(), 5);

        let (first_char, first_score) = result.char_scores[0];
        assert_eq!(first_char, 'H');
        assert_eq!(first_score, 0.99);
    }

    #[test]
    fn test_recognition_result_is_valid() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());

        // Thresholds at or below the confidence are accepted...
        for &threshold in &[0.9, 0.95] {
            assert!(result.is_valid(threshold));
        }
        // ...and anything above it is rejected.
        for &threshold in &[0.96, 0.99] {
            assert!(!result.is_valid(threshold));
        }
    }

    #[test]
    fn test_recognition_result_empty() {
        let result = RecognitionResult::new(String::new(), 0.0, vec![]);

        assert!(result.text.is_empty());
        assert_eq!(result.confidence, 0.0);
        assert!(!result.is_valid(0.1));
    }

    #[test]
    fn test_is_punctuation_common() {
        // English punctuation
        expect_punctuation(&[',', '.', '!', '?', ';', ':', '"', '\'']);
    }

    #[test]
    fn test_is_punctuation_chinese() {
        // Chinese punctuation
        // NOTE(review): several entries below are plain ASCII in this copy of
        // the file; the test name suggests full-width forms were intended -
        // confirm against the upstream source.
        expect_punctuation(&[',', '。', '!', '?', ';', ':', '、', '—', '…']);
    }

    #[test]
    fn test_is_punctuation_brackets() {
        expect_punctuation(&['(', ')', '[', ']', '{', '}', '「', '」', '《', '》']);
    }

    #[test]
    fn test_is_punctuation_false() {
        // Non-punctuation characters
        for &ch in &['A', 'z', '0', '中', '文', ' '] {
            assert!(!RecModel::is_punctuation(ch));
        }
    }
}