1use image::DynamicImage;
6use ndarray::ArrayD;
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::{preprocess_for_rec, NormalizeParams};
12
/// Outcome of recognizing a single text image.
#[derive(Debug, Clone)]
pub struct RecognitionResult {
    /// The decoded text.
    pub text: String,
    /// Overall confidence score attached to `text`.
    pub confidence: f32,
    /// Every character of the result paired with its individual score.
    pub char_scores: Vec<(char, f32)>,
}

impl RecognitionResult {
    /// Bundles a text, its overall confidence and the per-character scores.
    pub fn new(text: String, confidence: f32, char_scores: Vec<(char, f32)>) -> Self {
        RecognitionResult {
            text,
            confidence,
            char_scores,
        }
    }

    /// Returns `true` when the confidence reaches `threshold` (inclusive).
    ///
    /// A NaN confidence or a NaN threshold never counts as valid.
    pub fn is_valid(&self, threshold: f32) -> bool {
        threshold <= self.confidence
    }
}
39
/// Tunable settings for text recognition.
#[derive(Debug, Clone)]
pub struct RecOptions {
    /// Target image height handed to the recognition preprocessing step.
    pub target_height: u32,
    /// Minimum score a non-punctuation character needs to be kept while decoding.
    pub min_score: f32,
    /// Minimum score a punctuation character needs to be kept while decoding.
    pub punct_min_score: f32,
    /// Number of images grouped into one chunk during batch recognition.
    pub batch_size: usize,
    /// When `false`, batch recognition processes the images one at a time.
    pub enable_batch: bool,
}

impl Default for RecOptions {
    /// Height 48, score thresholds 0.3 / 0.1 (punctuation), batches of 8, batching on.
    fn default() -> Self {
        RecOptions {
            target_height: 48,
            min_score: 0.3,
            punct_min_score: 0.1,
            batch_size: 8,
            enable_batch: true,
        }
    }
}

impl RecOptions {
    /// Creates options holding the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the options with `target_height` replaced by `height`.
    pub fn with_target_height(self, height: u32) -> Self {
        Self {
            target_height: height,
            ..self
        }
    }

    /// Returns the options with `min_score` replaced by `score`.
    pub fn with_min_score(self, score: f32) -> Self {
        Self {
            min_score: score,
            ..self
        }
    }

    /// Returns the options with `punct_min_score` replaced by `score`.
    pub fn with_punct_min_score(self, score: f32) -> Self {
        Self {
            punct_min_score: score,
            ..self
        }
    }

    /// Returns the options with `batch_size` replaced by `size`.
    pub fn with_batch_size(self, size: usize) -> Self {
        Self {
            batch_size: size,
            ..self
        }
    }

    /// Returns the options with `enable_batch` replaced by `enable`.
    pub fn with_batch(self, enable: bool) -> Self {
        Self {
            enable_batch: enable,
            ..self
        }
    }
}
103
104pub struct RecModel {
106 engine: InferenceEngine,
107 charset: Vec<char>,
109 options: RecOptions,
110 normalize_params: NormalizeParams,
111}
112
/// Characters treated as punctuation when choosing the score threshold during
/// decoding.
///
/// The table holds ASCII punctuation followed by CJK punctuation. The fullwidth
/// forms are written as `\u{...}` escapes so they cannot be confused with (or
/// silently collapsed into) their ASCII look-alikes, which would leave
/// duplicate ASCII entries and make the fullwidth characters unrecognized.
const PUNCTUATIONS: [char; 49] = [
    // ASCII punctuation and symbols.
    ',', '.', '!', '?', ';', ':', '"', '\'', '(', ')', '[', ']', '{', '}', '-', '_', '/', '\\',
    '|', '@', '#', '$', '%', '&', '*', '+', '=', '~',
    // CJK sentence punctuation.
    '\u{FF0C}', // fullwidth comma
    '。',
    '\u{FF01}', // fullwidth exclamation mark
    '\u{FF1F}', // fullwidth question mark
    '\u{FF1B}', // fullwidth semicolon
    '\u{FF1A}', // fullwidth colon
    '、',
    // CJK brackets and quotation marks.
    '「', '」', '『', '』',
    '\u{FF08}', // fullwidth left parenthesis
    '\u{FF09}', // fullwidth right parenthesis
    '【', '】', '《', '》',
    // Dashes, ellipsis, middle dot, tilde.
    '—', '…', '·',
    '\u{FF5E}', // fullwidth tilde
];
119
120impl RecModel {
121 pub fn from_file(
128 model_path: impl AsRef<Path>,
129 charset_path: impl AsRef<Path>,
130 config: Option<InferenceConfig>,
131 ) -> OcrResult<Self> {
132 let engine = InferenceEngine::from_file(model_path, config)?;
133 let charset = Self::load_charset_from_file(charset_path)?;
134
135 Ok(Self {
136 engine,
137 charset,
138 options: RecOptions::default(),
139 normalize_params: NormalizeParams::paddle_rec(),
140 })
141 }
142
143 pub fn from_bytes(
145 model_bytes: &[u8],
146 charset_path: impl AsRef<Path>,
147 config: Option<InferenceConfig>,
148 ) -> OcrResult<Self> {
149 let engine = InferenceEngine::from_buffer(model_bytes, config)?;
150 let charset = Self::load_charset_from_file(charset_path)?;
151
152 Ok(Self {
153 engine,
154 charset,
155 options: RecOptions::default(),
156 normalize_params: NormalizeParams::paddle_rec(),
157 })
158 }
159
160 pub fn from_bytes_with_charset(
162 model_bytes: &[u8],
163 charset_bytes: &[u8],
164 config: Option<InferenceConfig>,
165 ) -> OcrResult<Self> {
166 let engine = InferenceEngine::from_buffer(model_bytes, config)?;
167 let charset = Self::parse_charset(charset_bytes)?;
168
169 Ok(Self {
170 engine,
171 charset,
172 options: RecOptions::default(),
173 normalize_params: NormalizeParams::paddle_rec(),
174 })
175 }
176
177 fn load_charset_from_file(path: impl AsRef<Path>) -> OcrResult<Vec<char>> {
179 let content = std::fs::read_to_string(path)?;
180 Self::parse_charset(content.as_bytes())
181 }
182
183 fn parse_charset(data: &[u8]) -> OcrResult<Vec<char>> {
185 let content = std::str::from_utf8(data)
186 .map_err(|e| OcrError::CharsetError(format!("UTF-8 decode error: {}", e)))?;
187
188 let mut charset: Vec<char> = vec![' ']; for ch in content.chars() {
193 if ch != '\n' && ch != '\r' {
194 charset.push(ch);
195 }
196 }
197
198 charset.push(' '); if charset.len() < 3 {
201 return Err(OcrError::CharsetError("Charset too small".to_string()));
202 }
203
204 Ok(charset)
205 }
206
207 pub fn with_options(mut self, options: RecOptions) -> Self {
209 self.options = options;
210 self
211 }
212
213 pub fn options(&self) -> &RecOptions {
215 &self.options
216 }
217
218 pub fn options_mut(&mut self) -> &mut RecOptions {
220 &mut self.options
221 }
222
223 pub fn charset_size(&self) -> usize {
225 self.charset.len()
226 }
227
228 pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
236 let input = preprocess_for_rec(image, self.options.target_height, &self.normalize_params);
238
239 let output = self.engine.run_dynamic(input.view().into_dyn())?;
241
242 self.decode_output(&output)
244 }
245
246 pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
248 let result = self.recognize(image)?;
249 Ok(result.text)
250 }
251
252 pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
260 if images.is_empty() {
261 return Ok(Vec::new());
262 }
263
264 if images.len() <= 2 || !self.options.enable_batch {
266 return images.iter().map(|img| self.recognize(img)).collect();
267 }
268
269 let mut results = Vec::with_capacity(images.len());
271
272 for chunk in images.chunks(self.options.batch_size) {
273 let batch_results = self.recognize_batch_internal(chunk)?;
274 results.extend(batch_results);
275 }
276
277 Ok(results)
278 }
279
280 pub fn recognize_batch_ref(
288 &self,
289 images: &[&DynamicImage],
290 ) -> OcrResult<Vec<RecognitionResult>> {
291 if images.is_empty() {
292 return Ok(Vec::new());
293 }
294
295 if images.len() <= 2 || !self.options.enable_batch {
297 return images.iter().map(|img| self.recognize(img)).collect();
298 }
299
300 let mut results = Vec::with_capacity(images.len());
302
303 for chunk in images.chunks(self.options.batch_size) {
304 let chunk_owned: Vec<DynamicImage> = chunk.iter().map(|img| (*img).clone()).collect();
306 let batch_results = self.recognize_batch_internal(&chunk_owned)?;
307 results.extend(batch_results);
308 }
309
310 Ok(results)
311 }
312
313 fn recognize_batch_internal(
315 &self,
316 images: &[DynamicImage],
317 ) -> OcrResult<Vec<RecognitionResult>> {
318 if images.is_empty() {
319 return Ok(Vec::new());
320 }
321
322 if images.len() == 1 {
324 return Ok(vec![self.recognize(&images[0])?]);
325 }
326
327 let batch_input = crate::preprocess::preprocess_batch_for_rec(
329 images,
330 self.options.target_height,
331 &self.normalize_params,
332 );
333
334 let batch_output = self.engine.run_dynamic(batch_input.view().into_dyn())?;
336
337 let shape = batch_output.shape();
339 if shape.len() != 3 {
340 return Err(OcrError::PostprocessError(format!(
341 "Batch inference output shape error: {:?}",
342 shape
343 )));
344 }
345
346 let batch_size = shape[0];
347 let mut results = Vec::with_capacity(batch_size);
348
349 for i in 0..batch_size {
350 let sample_output = batch_output.slice(ndarray::s![i, .., ..]).to_owned();
352 let sample_output_dyn = sample_output.into_dyn();
353 let result = self.decode_output(&sample_output_dyn)?;
354 results.push(result);
355 }
356
357 Ok(results)
358 }
359
360 fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<RecognitionResult> {
362 let shape = output.shape();
363
364 let (seq_len, num_classes) = if shape.len() == 3 {
366 (shape[1], shape[2])
367 } else if shape.len() == 2 {
368 (shape[0], shape[1])
369 } else {
370 return Err(OcrError::PostprocessError(format!(
371 "Invalid output shape: {:?}",
372 shape
373 )));
374 };
375
376 let output_data: Vec<f32> = output.iter().cloned().collect();
377
378 let mut char_scores = Vec::new();
380 let mut prev_idx = 0usize;
381
382 for t in 0..seq_len {
383 let start = t * num_classes;
385 let end = start + num_classes;
386 let probs = &output_data[start..end];
387
388 let (max_idx, &max_prob) = probs
389 .iter()
390 .enumerate()
391 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
392 .unwrap();
393
394 if max_idx != 0 && max_idx != prev_idx {
396 if max_idx < self.charset.len() {
397 let ch = self.charset[max_idx];
398
399 let score = max_prob;
402
403 let threshold = if Self::is_punctuation(ch) {
405 self.options.punct_min_score
406 } else {
407 self.options.min_score
408 };
409
410 if score >= threshold {
411 char_scores.push((ch, score));
412 }
413 }
414 }
415
416 prev_idx = max_idx;
417 }
418
419 let confidence = if char_scores.is_empty() {
421 0.0
422 } else {
423 char_scores.iter().map(|(_, s)| s).sum::<f32>() / char_scores.len() as f32
424 };
425
426 let text: String = char_scores.iter().map(|(ch, _)| ch).collect();
428
429 Ok(RecognitionResult::new(text, confidence, char_scores))
430 }
431
432 fn is_punctuation(ch: char) -> bool {
434 PUNCTUATIONS.contains(&ch)
435 }
436}
437
438impl RecModel {
440 pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
450 Ok(self.engine.run_dynamic(input)?)
451 }
452
453 pub fn input_shape(&self) -> &[usize] {
455 self.engine.input_shape()
456 }
457
458 pub fn output_shape(&self) -> &[usize] {
460 self.engine.output_shape()
461 }
462
463 pub fn charset(&self) -> &[char] {
465 &self.charset
466 }
467
468 pub fn get_char(&self, index: usize) -> Option<char> {
470 self.charset.get(index).copied()
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    /// Score table shared by the `RecognitionResult` tests.
    fn hello_scores() -> Vec<(char, f32)> {
        vec![
            ('H', 0.99),
            ('e', 0.94),
            ('l', 0.93),
            ('l', 0.95),
            ('o', 0.94),
        ]
    }

    /// Asserts that every character in `chars` is classified as punctuation.
    fn assert_all_punctuation(chars: &[char]) {
        for &ch in chars {
            assert!(RecModel::is_punctuation(ch));
        }
    }

    #[test]
    fn test_rec_options_default() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::default();

        assert_eq!(target_height, 48);
        assert_eq!(min_score, 0.3);
        assert_eq!(punct_min_score, 0.1);
        assert_eq!(batch_size, 8);
        assert!(enable_batch);
    }

    #[test]
    fn test_rec_options_builder() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::new()
            .with_target_height(32)
            .with_min_score(0.6)
            .with_punct_min_score(0.2)
            .with_batch_size(16)
            .with_batch(false);

        assert_eq!(target_height, 32);
        assert_eq!(min_score, 0.6);
        assert_eq!(punct_min_score, 0.2);
        assert_eq!(batch_size, 16);
        assert!(!enable_batch);
    }

    #[test]
    fn test_recognition_result_new() {
        let result = RecognitionResult::new(String::from("Hello"), 0.95, hello_scores());

        assert_eq!(result.text, "Hello");
        assert_eq!(result.confidence, 0.95);
        assert_eq!(result.char_scores.len(), 5);

        let (first_char, first_score) = result.char_scores[0];
        assert_eq!(first_char, 'H');
        assert_eq!(first_score, 0.99);
    }

    #[test]
    fn test_recognition_result_is_valid() {
        let result = RecognitionResult::new(String::from("Hello"), 0.95, hello_scores());

        // (threshold, expected validity) for a confidence of 0.95.
        let cases: [(f32, bool); 4] = [(0.9, true), (0.95, true), (0.96, false), (0.99, false)];
        for &(threshold, expected) in cases.iter() {
            assert_eq!(result.is_valid(threshold), expected);
        }
    }

    #[test]
    fn test_recognition_result_empty() {
        let result = RecognitionResult::new(String::new(), 0.0, Vec::new());

        assert!(result.text.is_empty());
        assert_eq!(result.confidence, 0.0);
        assert!(!result.is_valid(0.1));
    }

    #[test]
    fn test_is_punctuation_common() {
        assert_all_punctuation(&[',', '.', '!', '?', ';', ':', '"', '\'']);
    }

    #[test]
    fn test_is_punctuation_chinese() {
        assert_all_punctuation(&[',', '。', '!', '?', ';', ':', '、', '—', '…']);
    }

    #[test]
    fn test_is_punctuation_brackets() {
        assert_all_punctuation(&['(', ')', '[', ']', '{', '}', '「', '」', '《', '》']);
    }

    #[test]
    fn test_is_punctuation_false() {
        for &ch in ['A', 'z', '0', '中', '文', ' '].iter() {
            assert!(!RecModel::is_punctuation(ch));
        }
    }
}