use std::path::Path;
use image::DynamicImage;
use ndarray::Array2;
use super::backend::{build_backend, InferenceBackend};
use super::config::OcrConfig;
use super::error::{OcrError, OcrResult};
use super::postprocessor::{extract_boxes, DetectedBox};
use super::preprocessor::preprocess_for_detection;
/// Text detector that pairs an inference backend with the OCR configuration
/// used to drive preprocessing and box extraction.
pub struct TextDetector {
/// Backend that executes the detection model; created by `build_backend`
/// in `TextDetector::from_bytes`.
backend: Box<dyn InferenceBackend>,
/// Configuration read by `detect` (resize strategy, thresholds, candidate
/// limit, unclip ratio) and by `from_bytes` (thread count).
config: OcrConfig,
}
impl TextDetector {
    /// Creates a detector from a model file on disk.
    ///
    /// The file is read fully into memory and then handed to
    /// [`TextDetector::from_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`OcrError::ModelLoadError`] if the file cannot be read, or the
    /// error reported by `from_bytes` if the backend cannot be built.
    pub fn new(model_path: impl AsRef<Path>, config: OcrConfig) -> OcrResult<Self> {
        let model_bytes = std::fs::read(model_path.as_ref())
            .map_err(|e| OcrError::ModelLoadError(format!("Failed to read model file: {}", e)))?;
        Self::from_bytes(&model_bytes, config)
    }

    /// Creates a detector from an in-memory model.
    ///
    /// The backend is built with `config.num_threads`; `config` is stored and
    /// consulted on every call to [`TextDetector::detect`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `build_backend`.
    pub fn from_bytes(model_bytes: &[u8], config: OcrConfig) -> OcrResult<Self> {
        let backend = build_backend(model_bytes, config.num_threads)?;
        Ok(Self { backend, config })
    }

    /// Detects text regions in `image`.
    ///
    /// The image is preprocessed with the configured resize strategy, run
    /// through the backend, and the resulting probability map is converted to
    /// boxes using the configured thresholds, candidate limit and unclip ratio.
    /// The scale returned by preprocessing is forwarded to `extract_boxes`
    /// (presumably to map boxes back to original image coordinates; confirm
    /// against `extract_boxes`).
    ///
    /// # Errors
    ///
    /// Propagates errors from preprocessing, inference and box extraction.
    pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<DetectedBox>> {
        let (input_tensor, scale) =
            preprocess_for_detection(image, &self.config.det_resize_strategy)?;
        let prob_map = self.run_inference(&input_tensor)?;
        let boxes = extract_boxes(
            prob_map.view(),
            self.config.det_threshold,
            self.config.box_threshold,
            self.config.max_candidates,
            self.config.unclip_ratio,
            scale,
        )?;
        Ok(boxes)
    }

    /// Runs the backend on `input` and returns the 2D probability map.
    ///
    /// The backend output must be a 4D tensor; axes 2 and 3 are treated as
    /// height and width, and only the plane at index `[0, 0, .., ..]` is copied
    /// out.
    ///
    /// # Errors
    ///
    /// Returns [`OcrError::InferenceError`] if the output is not 4D or if its
    /// first two axes are empty (there would be no plane to read), and
    /// propagates any error from the backend itself.
    fn run_inference(&self, input: &ndarray::Array4<f32>) -> OcrResult<Array2<f32>> {
        let output_array = self.backend.run(input)?;
        let shape = output_array.shape();
        if shape.len() != 4 {
            return Err(OcrError::InferenceError(format!(
                "Unexpected output shape: {:?}, expected 4D tensor",
                shape
            )));
        }
        // Indexing `[0, 0, y, x]` below would panic on an empty batch or
        // channel axis, so report that as an inference error instead.
        if shape[0] == 0 || shape[1] == 0 {
            return Err(OcrError::InferenceError(format!(
                "Unexpected output shape: {:?}, batch and channel dimensions must be non-empty",
                shape
            )));
        }
        let (height, width) = (shape[2], shape[3]);
        Ok(Array2::from_shape_fn((height, width), |(y, x)| {
            output_array[[0, 0, y, x]]
        }))
    }

    /// Reports whether a model is loaded.
    ///
    /// Always returns `true`: both constructors in this impl only produce a
    /// `TextDetector` after the backend has been built successfully.
    pub fn is_loaded(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder must carry both detection thresholds through to the built
    /// configuration unchanged (within `f32::EPSILON`).
    #[test]
    fn test_detector_config() {
        let built = OcrConfig::builder()
            .det_threshold(0.4)
            .box_threshold(0.6)
            .build();

        // (actual, expected) pairs, checked in the same order as before.
        let checks = [(built.det_threshold, 0.4), (built.box_threshold, 0.6)];
        for (actual, expected) in checks {
            assert!((actual - expected).abs() < f32::EPSILON);
        }
    }
}