use std::path::Path;
use std::sync::Mutex;

use image::DynamicImage;
use ndarray::{Array2, Axis, Ix2};
use ort::session::Session;
use ort::value::TensorRef;

use super::config::OcrConfig;
use super::error::{OcrError, OcrResult};
use super::postprocessor::{extract_boxes, DetectedBox};
use super::preprocessor::preprocess_for_detection;
/// Text detector backed by an ONNX Runtime session: preprocesses an image,
/// runs the detection model, and post-processes the probability map into
/// text boxes (see `detect`).
pub struct TextDetector {
    // Behind a `Mutex` because `Session::run` requires `&mut` access while
    // `detect` takes `&self`; the `Option` models an unloaded state, which
    // `is_loaded` reports.
    session: Mutex<Option<Session>>,
    // Raw ONNX model bytes retained after loading. Not read anywhere in this
    // file (hence `dead_code`) — presumably kept so the session could be
    // rebuilt later; confirm before removing.
    #[allow(dead_code)]
    model_bytes: Option<Vec<u8>>,
    // Thresholds, resize strategy, and thread count used by `detect` and
    // session construction.
    config: OcrConfig,
}
impl TextDetector {
pub fn new(model_path: impl AsRef<Path>, config: OcrConfig) -> OcrResult<Self> {
let model_bytes = std::fs::read(model_path.as_ref())
.map_err(|e| OcrError::ModelLoadError(format!("Failed to read model file: {}", e)))?;
Self::from_bytes(&model_bytes, config)
}
pub fn from_bytes(model_bytes: &[u8], config: OcrConfig) -> OcrResult<Self> {
let session = Session::builder()
.map_err(|e| {
OcrError::ModelLoadError(format!("Failed to create session builder: {}", e))
})?
.with_intra_threads(config.num_threads)
.map_err(|e| OcrError::ModelLoadError(format!("Failed to set threads: {}", e)))?
.commit_from_memory(model_bytes)
.map_err(|e| OcrError::ModelLoadError(format!("Failed to load model: {}", e)))?;
Ok(Self {
session: Mutex::new(Some(session)),
model_bytes: Some(model_bytes.to_vec()),
config,
})
}
pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<DetectedBox>> {
let (input_tensor, scale) =
preprocess_for_detection(image, &self.config.det_resize_strategy)?;
let prob_map = self.run_inference(&input_tensor)?;
let boxes = extract_boxes(
prob_map.view(),
self.config.det_threshold,
self.config.box_threshold,
self.config.max_candidates,
self.config.unclip_ratio,
scale,
)?;
Ok(boxes)
}
fn run_inference(&self, input: &ndarray::Array4<f32>) -> OcrResult<Array2<f32>> {
let mut session_guard = self.session.lock().map_err(|e| {
OcrError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
let session = session_guard
.as_mut()
.ok_or_else(|| OcrError::InferenceError("Model session not initialized".to_string()))?;
let input_tensor = TensorRef::from_array_view(input).map_err(|e| {
OcrError::InferenceError(format!("Failed to create input tensor: {}", e))
})?;
let outputs = session
.run(ort::inputs!["x" => input_tensor])
.map_err(|e| OcrError::InferenceError(format!("Inference failed: {}", e)))?;
let (_, output_tensor) = outputs
.iter()
.next()
.ok_or_else(|| OcrError::InferenceError("No output tensor found".to_string()))?;
let output_array = output_tensor
.try_extract_array::<f32>()
.map_err(|e| OcrError::InferenceError(format!("Failed to extract output: {}", e)))?;
let shape = output_array.shape();
if shape.len() != 4 {
return Err(OcrError::InferenceError(format!(
"Unexpected output shape: {:?}, expected 4D tensor",
shape
)));
}
let height = shape[2];
let width = shape[3];
let mut prob_map = Array2::zeros((height, width));
for y in 0..height {
for x in 0..width {
prob_map[[y, x]] = output_array[[0, 0, y, x]];
}
}
Ok(prob_map)
}
pub fn is_loaded(&self) -> bool {
self.session.lock().map(|s| s.is_some()).unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Threshold values passed to the builder must survive into the built
    /// `OcrConfig` unchanged.
    #[test]
    fn test_detector_config() {
        let det = 0.4_f32;
        let boxes = 0.6_f32;

        let cfg = OcrConfig::builder()
            .det_threshold(det)
            .box_threshold(boxes)
            .build();

        assert!((cfg.det_threshold - det).abs() < f32::EPSILON);
        assert!((cfg.box_threshold - boxes).abs() < f32::EPSILON);
    }
}