1use image::DynamicImage;
6use ndarray::ArrayD;
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::preprocess::{preprocess_for_rec, NormalizeParams};
12
/// Text recognized from a single line image, with per-character scores.
///
/// `PartialEq` is derived so results can be compared directly (e.g. in tests or when
/// de-duplicating); it compares every field, including the floating-point scores.
#[derive(Debug, Clone, PartialEq)]
pub struct RecognitionResult {
    /// The recognized text.
    pub text: String,
    /// Aggregate confidence for the whole line. When produced by [`RecModel`] recognition this
    /// is the mean of the kept per-character scores, or `0.0` when no character was kept.
    pub confidence: f32,
    /// Each kept character paired with the probability the model assigned to it.
    pub char_scores: Vec<(char, f32)>,
}

impl RecognitionResult {
    /// Builds a result from its parts. No consistency check is made between `text`,
    /// `confidence` and `char_scores`.
    pub fn new(text: String, confidence: f32, char_scores: Vec<(char, f32)>) -> Self {
        Self {
            text,
            confidence,
            char_scores,
        }
    }

    /// Returns `true` when `confidence` is greater than or equal to `threshold`.
    #[must_use]
    pub fn is_valid(&self, threshold: f32) -> bool {
        self.confidence >= threshold
    }
}
39
/// Tunable parameters for text recognition.
///
/// Defaults: `target_height = 48`, `min_score = 0.3`, `punct_min_score = 0.1`,
/// `batch_size = 8`, `enable_batch = true`.
#[derive(Debug, Clone, PartialEq)]
pub struct RecOptions {
    /// Target height handed to the recognition preprocessing (presumably the resize height
    /// in pixels — confirm against `preprocess_for_rec`).
    pub target_height: u32,
    /// Minimum score a non-punctuation character needs to be kept in the output.
    pub min_score: f32,
    /// Minimum score a punctuation character needs to be kept in the output.
    pub punct_min_score: f32,
    /// Maximum number of images grouped into one inference call when batching is enabled.
    pub batch_size: usize,
    /// Whether several images may be grouped into a single inference call.
    pub enable_batch: bool,
}

impl Default for RecOptions {
    fn default() -> Self {
        Self {
            target_height: 48,
            min_score: 0.3,
            punct_min_score: 0.1,
            batch_size: 8,
            enable_batch: true,
        }
    }
}

impl RecOptions {
    /// Creates options with the default values (same as [`RecOptions::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`RecOptions::target_height`].
    #[must_use]
    pub fn with_target_height(mut self, height: u32) -> Self {
        self.target_height = height;
        self
    }

    /// Sets [`RecOptions::min_score`].
    #[must_use]
    pub fn with_min_score(mut self, score: f32) -> Self {
        self.min_score = score;
        self
    }

    /// Sets [`RecOptions::punct_min_score`].
    #[must_use]
    pub fn with_punct_min_score(mut self, score: f32) -> Self {
        self.punct_min_score = score;
        self
    }

    /// Sets [`RecOptions::batch_size`].
    #[must_use]
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Sets [`RecOptions::enable_batch`].
    #[must_use]
    pub fn with_batch(mut self, enable: bool) -> Self {
        self.enable_batch = enable;
        self
    }
}
103
104pub struct RecModel {
106 engine: InferenceEngine,
107 charset: Vec<char>,
109 options: RecOptions,
110 normalize_params: NormalizeParams,
111}
112
/// Characters that are filtered with the (lower) punctuation score threshold.
///
/// The full-width CJK forms are written as `\u{..}` escapes on purpose. They look almost
/// identical to their ASCII counterparts and are easily normalized to ASCII by editors or
/// copy/paste. That normalization previously left eight exact duplicates here and dropped
/// the full-width characters entirely.
const PUNCTUATIONS: [char; 49] = [
    // ASCII punctuation and symbols.
    ',', '.', '!', '?', ';', ':', '"', '\'', '(', ')', '[', ']', '{', '}', '-', '_', '/', '\\',
    '|', '@', '#', '$', '%', '&', '*', '+', '=', '~',
    // Full-width comma, ideographic full stop, full-width ! ? ; :, ideographic comma.
    '\u{FF0C}', '\u{3002}', '\u{FF01}', '\u{FF1F}', '\u{FF1B}', '\u{FF1A}', '\u{3001}',
    // CJK corner/lenticular/angle brackets, full-width parentheses, dashes and ellipsis.
    '\u{300C}', '\u{300D}', '\u{300E}', '\u{300F}', '\u{FF08}', '\u{FF09}', '\u{3010}',
    '\u{3011}', '\u{300A}', '\u{300B}', '\u{2014}', '\u{2026}', '\u{00B7}', '\u{FF5E}',
];
119
120impl RecModel {
121 pub fn from_file(
128 model_path: impl AsRef<Path>,
129 charset_path: impl AsRef<Path>,
130 config: Option<InferenceConfig>,
131 ) -> OcrResult<Self> {
132 let engine = InferenceEngine::from_file(model_path, config)?;
133 let charset = Self::load_charset_from_file(charset_path)?;
134
135 Ok(Self {
136 engine,
137 charset,
138 options: RecOptions::default(),
139 normalize_params: NormalizeParams::paddle_rec(),
140 })
141 }
142
143 pub fn from_bytes(
145 model_bytes: &[u8],
146 charset_path: impl AsRef<Path>,
147 config: Option<InferenceConfig>,
148 ) -> OcrResult<Self> {
149 let engine = InferenceEngine::from_buffer(model_bytes, config)?;
150 let charset = Self::load_charset_from_file(charset_path)?;
151
152 Ok(Self {
153 engine,
154 charset,
155 options: RecOptions::default(),
156 normalize_params: NormalizeParams::paddle_rec(),
157 })
158 }
159
160 pub fn from_bytes_with_charset(
162 model_bytes: &[u8],
163 charset_bytes: &[u8],
164 config: Option<InferenceConfig>,
165 ) -> OcrResult<Self> {
166 let engine = InferenceEngine::from_buffer(model_bytes, config)?;
167 let charset = Self::parse_charset(charset_bytes)?;
168
169 Ok(Self {
170 engine,
171 charset,
172 options: RecOptions::default(),
173 normalize_params: NormalizeParams::paddle_rec(),
174 })
175 }
176
177 fn load_charset_from_file(path: impl AsRef<Path>) -> OcrResult<Vec<char>> {
179 let content = std::fs::read_to_string(path)?;
180 Self::parse_charset(content.as_bytes())
181 }
182
183 fn parse_charset(data: &[u8]) -> OcrResult<Vec<char>> {
185 let content = std::str::from_utf8(data)
186 .map_err(|e| OcrError::CharsetError(format!("UTF-8 decode error: {}", e)))?;
187
188 let mut charset: Vec<char> = vec![' ']; for ch in content.chars() {
193 if ch != '\n' && ch != '\r' {
194 charset.push(ch);
195 }
196 }
197
198 charset.push(' '); if charset.len() < 3 {
201 return Err(OcrError::CharsetError("Charset too small".to_string()));
202 }
203
204 Ok(charset)
205 }
206
207 pub fn with_options(mut self, options: RecOptions) -> Self {
209 self.options = options;
210 self
211 }
212
213 pub fn options(&self) -> &RecOptions {
215 &self.options
216 }
217
218 pub fn options_mut(&mut self) -> &mut RecOptions {
220 &mut self.options
221 }
222
223 pub fn charset_size(&self) -> usize {
225 self.charset.len()
226 }
227
228 pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
236 let input = preprocess_for_rec(image, self.options.target_height, &self.normalize_params)?;
238
239 let output = self.engine.run_dynamic(input.view().into_dyn())?;
241
242 self.decode_output(&output)
244 }
245
246 pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
248 let result = self.recognize(image)?;
249 Ok(result.text)
250 }
251
252 pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
260 if images.is_empty() {
261 return Ok(Vec::new());
262 }
263
264 if images.len() <= 2 || !self.options.enable_batch {
266 return images.iter().map(|img| self.recognize(img)).collect();
267 }
268
269 let mut results = Vec::with_capacity(images.len());
271
272 for chunk in images.chunks(self.options.batch_size) {
273 let batch_results = self.recognize_batch_internal(chunk)?;
274 results.extend(batch_results);
275 }
276
277 Ok(results)
278 }
279
280 pub fn recognize_batch_ref(
288 &self,
289 images: &[&DynamicImage],
290 ) -> OcrResult<Vec<RecognitionResult>> {
291 if images.is_empty() {
292 return Ok(Vec::new());
293 }
294
295 if images.len() <= 2 || !self.options.enable_batch {
297 return images.iter().map(|img| self.recognize(img)).collect();
298 }
299
300 let mut results = Vec::with_capacity(images.len());
302
303 for chunk in images.chunks(self.options.batch_size) {
304 let chunk_owned: Vec<DynamicImage> = chunk.iter().map(|img| (*img).clone()).collect();
306 let batch_results = self.recognize_batch_internal(&chunk_owned)?;
307 results.extend(batch_results);
308 }
309
310 Ok(results)
311 }
312
313 fn recognize_batch_internal(
315 &self,
316 images: &[DynamicImage],
317 ) -> OcrResult<Vec<RecognitionResult>> {
318 if images.is_empty() {
319 return Ok(Vec::new());
320 }
321
322 if images.len() == 1 {
324 return Ok(vec![self.recognize(&images[0])?]);
325 }
326
327 let batch_input = crate::preprocess::preprocess_batch_for_rec(
329 images,
330 self.options.target_height,
331 &self.normalize_params,
332 )?;
333
334 let batch_output = self.engine.run_dynamic(batch_input.view().into_dyn())?;
336
337 let shape = batch_output.shape();
339 if shape.len() != 3 {
340 return Err(OcrError::PostprocessError(format!(
341 "Batch inference output shape error: {:?}",
342 shape
343 )));
344 }
345
346 let batch_size = shape[0];
347 let mut results = Vec::with_capacity(batch_size);
348
349 for i in 0..batch_size {
350 let sample_output = batch_output.slice(ndarray::s![i, .., ..]).to_owned();
352 let sample_output_dyn = sample_output.into_dyn();
353 let result = self.decode_output(&sample_output_dyn)?;
354 results.push(result);
355 }
356
357 Ok(results)
358 }
359
360 fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<RecognitionResult> {
362 let shape = output.shape();
363
364 let (seq_len, num_classes) = if shape.len() == 3 {
366 (shape[1], shape[2])
367 } else if shape.len() == 2 {
368 (shape[0], shape[1])
369 } else {
370 return Err(OcrError::PostprocessError(format!(
371 "Invalid output shape: {:?}",
372 shape
373 )));
374 };
375
376 let output_data: Vec<f32> = output.iter().cloned().collect();
377
378 let mut char_scores = Vec::new();
380 let mut prev_idx = 0usize;
381
382 for t in 0..seq_len {
383 let start = t * num_classes;
385 let end = start + num_classes;
386 let probs = &output_data[start..end];
387
388 let (max_idx, &max_prob) = probs
389 .iter()
390 .enumerate()
391 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
392 .ok_or_else(|| {
393 OcrError::PostprocessError("Empty probability slice in CTC decoding".into())
394 })?;
395
396 if max_idx != 0 && max_idx != prev_idx {
398 if max_idx < self.charset.len() {
399 let ch = self.charset[max_idx];
400
401 let score = max_prob;
404
405 let threshold = if Self::is_punctuation(ch) {
407 self.options.punct_min_score
408 } else {
409 self.options.min_score
410 };
411
412 if score >= threshold {
413 char_scores.push((ch, score));
414 }
415 }
416 }
417
418 prev_idx = max_idx;
419 }
420
421 let confidence = if char_scores.is_empty() {
423 0.0
424 } else {
425 char_scores.iter().map(|(_, s)| s).sum::<f32>() / char_scores.len() as f32
426 };
427
428 let text: String = char_scores.iter().map(|(ch, _)| ch).collect();
430
431 Ok(RecognitionResult::new(text, confidence, char_scores))
432 }
433
434 fn is_punctuation(ch: char) -> bool {
436 PUNCTUATIONS.contains(&ch)
437 }
438}
439
440impl RecModel {
442 pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
452 Ok(self.engine.run_dynamic(input)?)
453 }
454
455 pub fn input_shape(&self) -> &[usize] {
457 self.engine.input_shape()
458 }
459
460 pub fn output_shape(&self) -> &[usize] {
462 self.engine.output_shape()
463 }
464
465 pub fn charset(&self) -> &[char] {
467 &self.charset
468 }
469
470 pub fn get_char(&self, index: usize) -> Option<char> {
472 self.charset.get(index).copied()
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-character scores spelling "Hello", shared by the `RecognitionResult` tests.
    fn hello_scores() -> Vec<(char, f32)> {
        vec![('H', 0.99), ('e', 0.94), ('l', 0.93), ('l', 0.95), ('o', 0.94)]
    }

    #[test]
    fn test_rec_options_default() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::default();

        assert_eq!(target_height, 48);
        assert_eq!(min_score, 0.3);
        assert_eq!(punct_min_score, 0.1);
        assert_eq!(batch_size, 8);
        assert!(enable_batch);
    }

    #[test]
    fn test_rec_options_builder() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::new()
            .with_target_height(32)
            .with_min_score(0.6)
            .with_punct_min_score(0.2)
            .with_batch_size(16)
            .with_batch(false);

        assert_eq!(target_height, 32);
        assert_eq!(min_score, 0.6);
        assert_eq!(punct_min_score, 0.2);
        assert_eq!(batch_size, 16);
        assert!(!enable_batch);
    }

    #[test]
    fn test_recognition_result_new() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());

        assert_eq!(result.text, "Hello");
        assert_eq!(result.confidence, 0.95);
        assert_eq!(result.char_scores.len(), 5);

        let (first_char, first_score) = result.char_scores[0];
        assert_eq!(first_char, 'H');
        assert_eq!(first_score, 0.99);
    }

    #[test]
    fn test_recognition_result_is_valid() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());

        for threshold in [0.9, 0.95] {
            assert!(result.is_valid(threshold));
        }
        for threshold in [0.96, 0.99] {
            assert!(!result.is_valid(threshold));
        }
    }

    #[test]
    fn test_recognition_result_empty() {
        let result = RecognitionResult::new(String::new(), 0.0, Vec::new());

        assert!(result.text.is_empty());
        assert_eq!(result.confidence, 0.0);
        assert!(!result.is_valid(0.1));
    }

    #[test]
    fn test_is_punctuation_common() {
        for ch in [',', '.', '!', '?', ';', ':', '"', '\''] {
            assert!(RecModel::is_punctuation(ch));
        }
    }

    #[test]
    fn test_is_punctuation_chinese() {
        for ch in [',', '。', '!', '?', ';', ':', '、', '—', '…'] {
            assert!(RecModel::is_punctuation(ch));
        }
    }

    #[test]
    fn test_is_punctuation_brackets() {
        for ch in ['(', ')', '[', ']', '{', '}', '「', '」', '《', '》'] {
            assert!(RecModel::is_punctuation(ch));
        }
    }

    #[test]
    fn test_is_punctuation_false() {
        for ch in ['A', 'z', '0', '中', '文', ' '] {
            assert!(!RecModel::is_punctuation(ch));
        }
    }
}