use image::DynamicImage;
use ndarray::ArrayD;
use std::path::Path;
use crate::error::{OcrError, OcrResult};
use crate::mnn::{InferenceConfig, InferenceEngine};
use crate::preprocess::{preprocess_for_rec, NormalizeParams};
/// Outcome of recognizing one text image.
#[derive(Debug, Clone)]
pub struct RecognitionResult {
    /// The recognized text.
    pub text: String,
    /// Overall confidence for `text`; compared against a threshold by
    /// [`RecognitionResult::is_valid`].
    pub confidence: f32,
    /// Each character of the result paired with its individual score.
    pub char_scores: Vec<(char, f32)>,
}

impl RecognitionResult {
    /// Assembles a result from its three parts without any validation.
    pub fn new(text: String, confidence: f32, char_scores: Vec<(char, f32)>) -> Self {
        RecognitionResult {
            char_scores,
            confidence,
            text,
        }
    }

    /// Returns `true` when the stored confidence reaches `threshold`
    /// (inclusive).
    pub fn is_valid(&self, threshold: f32) -> bool {
        threshold <= self.confidence
    }
}
/// Tunable settings for text recognition.
#[derive(Debug, Clone)]
pub struct RecOptions {
    /// Image height handed to the recognition preprocessing step.
    pub target_height: u32,
    /// Minimum score a non-punctuation character needs to be kept.
    pub min_score: f32,
    /// Minimum score a punctuation character needs to be kept.
    pub punct_min_score: f32,
    /// Number of images grouped into one batched inference call.
    pub batch_size: usize,
    /// Whether batched inference is used at all.
    pub enable_batch: bool,
}

impl Default for RecOptions {
    /// Height 48, score thresholds 0.3 / 0.1 (punctuation), batches of 8,
    /// batching enabled.
    fn default() -> Self {
        RecOptions {
            enable_batch: true,
            batch_size: 8,
            punct_min_score: 0.1,
            min_score: 0.3,
            target_height: 48,
        }
    }
}

impl RecOptions {
    /// Same as [`RecOptions::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the options with `target_height` replaced by `height`.
    pub fn with_target_height(self, height: u32) -> Self {
        Self {
            target_height: height,
            ..self
        }
    }

    /// Returns the options with `min_score` replaced by `score`.
    pub fn with_min_score(self, score: f32) -> Self {
        Self {
            min_score: score,
            ..self
        }
    }

    /// Returns the options with `punct_min_score` replaced by `score`.
    pub fn with_punct_min_score(self, score: f32) -> Self {
        Self {
            punct_min_score: score,
            ..self
        }
    }

    /// Returns the options with `batch_size` replaced by `size`.
    pub fn with_batch_size(self, size: usize) -> Self {
        Self {
            batch_size: size,
            ..self
        }
    }

    /// Returns the options with `enable_batch` replaced by `enable`.
    pub fn with_batch(self, enable: bool) -> Self {
        Self {
            enable_batch: enable,
            ..self
        }
    }
}
/// Text recognizer: owns an inference engine, the character table used to
/// turn class indices into characters, and the recognition settings.
pub struct RecModel {
/// Inference backend; every recognition call goes through its `run_dynamic`.
engine: InferenceEngine,
/// Class-index → character table built by `parse_charset`.
charset: Vec<char>,
/// Thresholds, target height and batching settings.
options: RecOptions,
/// Normalization parameters passed to the preprocessing functions.
normalize_params: NormalizeParams,
}
/// Characters treated as punctuation when choosing a score threshold.
///
/// Holds ASCII punctuation/symbols followed by their CJK counterparts. The
/// fullwidth forms are spelled as `\u{…}` escapes because they are visually
/// indistinguishable from the ASCII entries above and are easily collapsed
/// into them by editors or text normalization, which would leave the ASCII
/// characters listed twice and the fullwidth ones missing.
const PUNCTUATIONS: [char; 49] = [
    // ASCII punctuation and symbols.
    ',', '.', '!', '?', ';', ':', '"', '\'', '(', ')', '[', ']', '{', '}', '-', '_', '/', '\\',
    '|', '@', '#', '$', '%', '&', '*', '+', '=', '~',
    // Fullwidth comma, ideographic full stop, fullwidth ! ? ; :, ideographic comma.
    '\u{FF0C}', '。', '\u{FF01}', '\u{FF1F}', '\u{FF1B}', '\u{FF1A}', '、',
    // CJK quotation marks and brackets; fullwidth parentheses are U+FF08 / U+FF09.
    '「', '」', '『', '』', '\u{FF08}', '\u{FF09}', '【', '】', '《', '》',
    // Dash, ellipsis, middle dot, fullwidth tilde (U+FF5E).
    '—', '…', '·', '\u{FF5E}',
];
impl RecModel {
pub fn from_file(
model_path: impl AsRef<Path>,
charset_path: impl AsRef<Path>,
config: Option<InferenceConfig>,
) -> OcrResult<Self> {
let engine = InferenceEngine::from_file(model_path, config)?;
let charset = Self::load_charset_from_file(charset_path)?;
Ok(Self {
engine,
charset,
options: RecOptions::default(),
normalize_params: NormalizeParams::paddle_rec(),
})
}
pub fn from_bytes(
model_bytes: &[u8],
charset_path: impl AsRef<Path>,
config: Option<InferenceConfig>,
) -> OcrResult<Self> {
let engine = InferenceEngine::from_buffer(model_bytes, config)?;
let charset = Self::load_charset_from_file(charset_path)?;
Ok(Self {
engine,
charset,
options: RecOptions::default(),
normalize_params: NormalizeParams::paddle_rec(),
})
}
pub fn from_bytes_with_charset(
model_bytes: &[u8],
charset_bytes: &[u8],
config: Option<InferenceConfig>,
) -> OcrResult<Self> {
let engine = InferenceEngine::from_buffer(model_bytes, config)?;
let charset = Self::parse_charset(charset_bytes)?;
Ok(Self {
engine,
charset,
options: RecOptions::default(),
normalize_params: NormalizeParams::paddle_rec(),
})
}
fn load_charset_from_file(path: impl AsRef<Path>) -> OcrResult<Vec<char>> {
let content = std::fs::read_to_string(path)?;
Self::parse_charset(content.as_bytes())
}
fn parse_charset(data: &[u8]) -> OcrResult<Vec<char>> {
let content = std::str::from_utf8(data)
.map_err(|e| OcrError::CharsetError(format!("UTF-8 decode error: {}", e)))?;
let mut charset: Vec<char> = vec![' '];
for ch in content.chars() {
if ch != '\n' && ch != '\r' {
charset.push(ch);
}
}
charset.push(' ');
if charset.len() < 3 {
return Err(OcrError::CharsetError("Charset too small".to_string()));
}
Ok(charset)
}
pub fn with_options(mut self, options: RecOptions) -> Self {
self.options = options;
self
}
pub fn options(&self) -> &RecOptions {
&self.options
}
pub fn options_mut(&mut self) -> &mut RecOptions {
&mut self.options
}
pub fn charset_size(&self) -> usize {
self.charset.len()
}
pub fn recognize(&self, image: &DynamicImage) -> OcrResult<RecognitionResult> {
let input = preprocess_for_rec(image, self.options.target_height, &self.normalize_params)?;
let output = self.engine.run_dynamic(input.view().into_dyn())?;
self.decode_output(&output)
}
pub fn recognize_text(&self, image: &DynamicImage) -> OcrResult<String> {
let result = self.recognize(image)?;
Ok(result.text)
}
pub fn recognize_batch(&self, images: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
if images.is_empty() {
return Ok(Vec::new());
}
if images.len() <= 2 || !self.options.enable_batch {
return images.iter().map(|img| self.recognize(img)).collect();
}
let mut results = Vec::with_capacity(images.len());
for chunk in images.chunks(self.options.batch_size) {
let batch_results = self.recognize_batch_internal(chunk)?;
results.extend(batch_results);
}
Ok(results)
}
pub fn recognize_batch_ref(
&self,
images: &[&DynamicImage],
) -> OcrResult<Vec<RecognitionResult>> {
if images.is_empty() {
return Ok(Vec::new());
}
if images.len() <= 2 || !self.options.enable_batch {
return images.iter().map(|img| self.recognize(img)).collect();
}
let mut results = Vec::with_capacity(images.len());
for chunk in images.chunks(self.options.batch_size) {
let chunk_owned: Vec<DynamicImage> = chunk.iter().map(|img| (*img).clone()).collect();
let batch_results = self.recognize_batch_internal(&chunk_owned)?;
results.extend(batch_results);
}
Ok(results)
}
fn recognize_batch_internal(
&self,
images: &[DynamicImage],
) -> OcrResult<Vec<RecognitionResult>> {
if images.is_empty() {
return Ok(Vec::new());
}
if images.len() == 1 {
return Ok(vec![self.recognize(&images[0])?]);
}
let batch_input = crate::preprocess::preprocess_batch_for_rec(
images,
self.options.target_height,
&self.normalize_params,
)?;
let batch_output = self.engine.run_dynamic(batch_input.view().into_dyn())?;
let shape = batch_output.shape();
if shape.len() != 3 {
return Err(OcrError::PostprocessError(format!(
"Batch inference output shape error: {:?}",
shape
)));
}
let batch_size = shape[0];
let mut results = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let sample_output = batch_output.slice(ndarray::s![i, .., ..]).to_owned();
let sample_output_dyn = sample_output.into_dyn();
let result = self.decode_output(&sample_output_dyn)?;
results.push(result);
}
Ok(results)
}
fn decode_output(&self, output: &ArrayD<f32>) -> OcrResult<RecognitionResult> {
let shape = output.shape();
let (seq_len, num_classes) = if shape.len() == 3 {
(shape[1], shape[2])
} else if shape.len() == 2 {
(shape[0], shape[1])
} else {
return Err(OcrError::PostprocessError(format!(
"Invalid output shape: {:?}",
shape
)));
};
let output_data: Vec<f32> = output.iter().cloned().collect();
let mut char_scores = Vec::new();
let mut prev_idx = 0usize;
for t in 0..seq_len {
let start = t * num_classes;
let end = start + num_classes;
let probs = &output_data[start..end];
let (max_idx, &max_prob) = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.ok_or_else(|| {
OcrError::PostprocessError("Empty probability slice in CTC decoding".into())
})?;
if max_idx != 0 && max_idx != prev_idx {
if max_idx < self.charset.len() {
let ch = self.charset[max_idx];
let score = max_prob;
let threshold = if Self::is_punctuation(ch) {
self.options.punct_min_score
} else {
self.options.min_score
};
if score >= threshold {
char_scores.push((ch, score));
}
}
}
prev_idx = max_idx;
}
let confidence = if char_scores.is_empty() {
0.0
} else {
char_scores.iter().map(|(_, s)| s).sum::<f32>() / char_scores.len() as f32
};
let text: String = char_scores.iter().map(|(ch, _)| ch).collect();
Ok(RecognitionResult::new(text, confidence, char_scores))
}
fn is_punctuation(ch: char) -> bool {
PUNCTUATIONS.contains(&ch)
}
}
impl RecModel {
    /// Feeds `input` straight to the inference engine and returns its output,
    /// with no preprocessing or decoding.
    pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
        let raw_output = self.engine.run_dynamic(input)?;
        Ok(raw_output)
    }

    /// Input shape as reported by the inference engine.
    pub fn input_shape(&self) -> &[usize] {
        let engine = &self.engine;
        engine.input_shape()
    }

    /// Output shape as reported by the inference engine.
    pub fn output_shape(&self) -> &[usize] {
        let engine = &self.engine;
        engine.output_shape()
    }

    /// The full class-index → character table.
    pub fn charset(&self) -> &[char] {
        self.charset.as_slice()
    }

    /// Character stored at `index`, or `None` when `index` is out of range.
    pub fn get_char(&self, index: usize) -> Option<char> {
        self.charset().get(index).copied()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-character scores shared by the `RecognitionResult` tests.
    fn hello_scores() -> Vec<(char, f32)> {
        vec![
            ('H', 0.99),
            ('e', 0.94),
            ('l', 0.93),
            ('l', 0.95),
            ('o', 0.94),
        ]
    }

    /// Asserts that every character in `chars` is classified as punctuation.
    fn assert_all_punctuation(chars: &[char]) {
        for &ch in chars {
            assert!(RecModel::is_punctuation(ch));
        }
    }

    #[test]
    fn test_rec_options_default() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::default();
        assert_eq!(target_height, 48);
        assert_eq!(min_score, 0.3);
        assert_eq!(punct_min_score, 0.1);
        assert_eq!(batch_size, 8);
        assert!(enable_batch);
    }

    #[test]
    fn test_rec_options_builder() {
        let RecOptions {
            target_height,
            min_score,
            punct_min_score,
            batch_size,
            enable_batch,
        } = RecOptions::new()
            .with_target_height(32)
            .with_min_score(0.6)
            .with_punct_min_score(0.2)
            .with_batch_size(16)
            .with_batch(false);
        assert_eq!(target_height, 32);
        assert_eq!(min_score, 0.6);
        assert_eq!(punct_min_score, 0.2);
        assert_eq!(batch_size, 16);
        assert!(!enable_batch);
    }

    #[test]
    fn test_recognition_result_new() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());
        assert_eq!(result.text, "Hello");
        assert_eq!(result.confidence, 0.95);
        assert_eq!(result.char_scores.len(), 5);
        let (first_char, first_score) = result.char_scores[0];
        assert_eq!(first_char, 'H');
        assert_eq!(first_score, 0.99);
    }

    #[test]
    fn test_recognition_result_is_valid() {
        let result = RecognitionResult::new("Hello".to_string(), 0.95, hello_scores());
        // Thresholds at or below the confidence pass; anything above fails.
        for &threshold in &[0.9, 0.95] {
            assert!(result.is_valid(threshold));
        }
        for &threshold in &[0.96, 0.99] {
            assert!(!result.is_valid(threshold));
        }
    }

    #[test]
    fn test_recognition_result_empty() {
        let result = RecognitionResult::new(String::new(), 0.0, Vec::new());
        assert!(result.text.is_empty());
        assert_eq!(result.confidence, 0.0);
        assert!(!result.is_valid(0.1));
    }

    #[test]
    fn test_is_punctuation_common() {
        assert_all_punctuation(&[',', '.', '!', '?', ';', ':', '"', '\'']);
    }

    #[test]
    fn test_is_punctuation_chinese() {
        assert_all_punctuation(&[',', '。', '!', '?', ';', ':', '、', '—', '…']);
    }

    #[test]
    fn test_is_punctuation_brackets() {
        assert_all_punctuation(&['(', ')', '[', ']', '{', '}', '「', '」', '《', '》']);
    }

    #[test]
    fn test_is_punctuation_false() {
        for &ch in &['A', 'z', '0', '中', '文', ' '] {
            assert!(!RecModel::is_punctuation(ch));
        }
    }
}