use std::path::Path;
use std::sync::Mutex;
use image::DynamicImage;
use ndarray::Array4;
use ort::session::Session;
use ort::value::TensorRef;
use super::config::OcrConfig;
use super::error::{OcrError, OcrResult};
use super::preprocessor::preprocess_for_recognition;
/// Outcome of recognizing the text content of one cropped text-line image.
#[derive(Debug, Clone)]
pub struct RecognitionResult {
/// The decoded string after CTC collapse (repeats merged, blanks dropped).
pub text: String,
/// Geometric mean of the per-character probabilities; 0.0 when no text was decoded.
pub confidence: f32,
/// One probability per emitted character, in the same order as `text`.
pub char_confidences: Vec<f32>,
}
/// CTC-based text recognizer backed by an ONNX Runtime session.
pub struct TextRecognizer {
/// The ONNX session. The `Mutex` serializes inference calls (`run` needs
/// `&mut Session`); `Option` models an unloaded state (see `is_loaded`).
session: Mutex<Option<Session>>,
/// Copy of the raw model bytes kept alongside the live session.
#[allow(dead_code)]
model_bytes: Option<Vec<u8>>,
/// Class-index → character table; index 0 is the CTC blank ('\0' sentinel).
dictionary: Vec<char>,
/// Runtime settings (thread count, target crop height, ...).
config: OcrConfig,
}
impl TextRecognizer {
pub fn new(
model_path: impl AsRef<Path>,
dict_path: impl AsRef<Path>,
config: OcrConfig,
) -> OcrResult<Self> {
let model_bytes = std::fs::read(model_path.as_ref())
.map_err(|e| OcrError::ModelLoadError(format!("Failed to read model file: {}", e)))?;
let dict_content = std::fs::read_to_string(dict_path.as_ref())
.map_err(|e| OcrError::DictionaryError(format!("Failed to read dictionary: {}", e)))?;
Self::from_bytes(&model_bytes, &dict_content, config)
}
pub fn from_bytes(
model_bytes: &[u8],
dict_content: &str,
config: OcrConfig,
) -> OcrResult<Self> {
let dictionary = Self::parse_dictionary(dict_content)?;
let session = Session::builder()
.map_err(|e| {
OcrError::ModelLoadError(format!("Failed to create session builder: {}", e))
})?
.with_intra_threads(config.num_threads)
.map_err(|e| OcrError::ModelLoadError(format!("Failed to set threads: {}", e)))?
.commit_from_memory(model_bytes)
.map_err(|e| OcrError::ModelLoadError(format!("Failed to load model: {}", e)))?;
Ok(Self {
session: Mutex::new(Some(session)),
model_bytes: Some(model_bytes.to_vec()),
dictionary,
config,
})
}
fn parse_dictionary(content: &str) -> OcrResult<Vec<char>> {
let chars: Vec<char> = content
.lines()
.filter(|line| !line.is_empty())
.filter_map(|line| line.chars().next())
.collect();
if chars.is_empty() {
return Err(OcrError::DictionaryError("Dictionary is empty".to_string()));
}
let mut dict = Vec::with_capacity(chars.len() + 1);
dict.push('\0'); dict.extend(chars);
Ok(dict)
}
pub fn recognize(&self, crop: &DynamicImage) -> OcrResult<RecognitionResult> {
let input_tensor = preprocess_for_recognition(crop, self.config.rec_target_height)?;
self.run_inference(&input_tensor)
}
pub fn recognize_batch(&self, crops: &[DynamicImage]) -> OcrResult<Vec<RecognitionResult>> {
crops.iter().map(|crop| self.recognize(crop)).collect()
}
fn run_inference(&self, input: &Array4<f32>) -> OcrResult<RecognitionResult> {
let mut session_guard = self.session.lock().map_err(|e| {
OcrError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
let session = session_guard
.as_mut()
.ok_or_else(|| OcrError::InferenceError("Model session not initialized".to_string()))?;
let input_tensor = TensorRef::from_array_view(input).map_err(|e| {
OcrError::InferenceError(format!("Failed to create input tensor: {}", e))
})?;
let outputs = session
.run(ort::inputs!["x" => input_tensor])
.map_err(|e| OcrError::InferenceError(format!("Inference failed: {}", e)))?;
let (_, output_tensor) = outputs
.iter()
.next()
.ok_or_else(|| OcrError::InferenceError("No output tensor found".to_string()))?;
let output_array = output_tensor
.try_extract_array::<f32>()
.map_err(|e| OcrError::InferenceError(format!("Failed to extract output: {}", e)))?;
self.ctc_greedy_decode(&output_array)
}
fn ctc_greedy_decode(&self, output: &ndarray::ArrayViewD<f32>) -> OcrResult<RecognitionResult> {
let shape = output.shape();
let (seq_len, num_classes) = match shape.len() {
2 => (shape[0], shape[1]),
3 => (shape[1], shape[2]),
_ => {
return Err(OcrError::InferenceError(format!(
"Unexpected output shape: {:?}, expected 2D or 3D tensor",
shape
)));
},
};
let blank_idx = 0; let mut text = String::new();
let mut char_confidences = Vec::new();
let mut prev_idx = blank_idx;
for t in 0..seq_len {
let mut max_idx = 0;
let mut max_conf = f32::MIN;
for c in 0..num_classes {
let prob = if shape.len() == 3 {
output[[0, t, c]]
} else {
output[[t, c]]
};
if prob > max_conf {
max_conf = prob;
max_idx = c;
}
}
if max_idx != prev_idx && max_idx != blank_idx && max_idx < self.dictionary.len() {
let ch = self.dictionary[max_idx];
if ch != '\0' {
text.push(ch);
char_confidences.push(max_conf);
}
}
prev_idx = max_idx;
}
let confidence = if char_confidences.is_empty() {
0.0
} else {
let log_sum: f32 = char_confidences.iter().map(|c| c.ln()).sum();
(log_sum / char_confidences.len() as f32).exp()
};
Ok(RecognitionResult {
text,
confidence,
char_confidences,
})
}
pub fn dictionary(&self) -> &[char] {
&self.dictionary
}
pub fn is_loaded(&self) -> bool {
self.session.lock().map(|s| s.is_some()).unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_dictionary() {
        // Six single-character lines plus the injected blank at index 0.
        let parsed = TextRecognizer::parse_dictionary("a\nb\nc\n1\n2\n3").unwrap();
        assert_eq!(parsed.len(), 7);
        assert_eq!(parsed[0], '\0');
        assert_eq!(parsed[1], 'a');
        assert_eq!(parsed[6], '3');
    }

    #[test]
    fn test_parse_dictionary_empty() {
        // A dictionary with no characters must be rejected.
        assert!(TextRecognizer::parse_dictionary("").is_err());
    }

    #[test]
    fn test_recognition_result() {
        let confidences = vec![0.99, 0.98, 0.92, 0.93, 0.95];
        let result = RecognitionResult {
            text: String::from("Hello"),
            confidence: 0.95,
            char_confidences: confidences,
        };
        assert_eq!(result.text, "Hello");
        assert!(result.confidence > 0.9);
        assert_eq!(result.char_confidences.len(), 5);
    }
}