use super::builder::PredictorBuilderState;
use crate::TaskPredictorBuilder;
use crate::core::OcrResult;
use crate::core::errors::OCRError;
use crate::core::traits::OrtConfigurable;
use crate::core::traits::adapter::AdapterBuilder;
use crate::core::traits::task::ImageTaskInput;
use crate::domain::adapters::TextRecognitionAdapterBuilder;
use crate::domain::tasks::text_recognition::{TextRecognitionConfig, TextRecognitionTask};
use crate::predictors::TaskPredictorCore;
use image::RgbImage;
use std::path::{Path, PathBuf};
/// Output of [`TextRecognitionPredictor::predict`].
///
/// Both vectors are copied verbatim from the underlying task output.
/// NOTE(review): presumably `texts[i]` and `scores[i]` describe the same
/// input image — confirm against the task output definition.
#[derive(Debug, Clone)]
pub struct TextRecognitionResult {
/// Recognized text strings as reported by the recognition task.
pub texts: Vec<String>,
/// Scores reported by the recognition task alongside `texts`.
pub scores: Vec<f32>,
}
/// Predictor for the text recognition task.
///
/// Construct it through [`TextRecognitionPredictor::builder`]; inference is
/// delegated to the wrapped [`TaskPredictorCore`].
pub struct TextRecognitionPredictor {
// Generic predictor core specialised for `TextRecognitionTask`.
core: TaskPredictorCore<TextRecognitionTask>,
}
impl TextRecognitionPredictor {
    /// Returns a new [`TextRecognitionPredictorBuilder`] used to configure and
    /// construct a predictor.
    pub fn builder() -> TextRecognitionPredictorBuilder {
        TextRecognitionPredictorBuilder::new()
    }

    /// Runs text recognition on `images` and returns the recognized texts
    /// together with their scores.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying predictor core.
    pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<TextRecognitionResult> {
        let recognized = self.core.predict(ImageTaskInput::new(images))?;

        // Move the two fields out of the task output into the public result type.
        let texts = recognized.texts;
        let scores = recognized.scores;
        Ok(TextRecognitionResult { texts, scores })
    }
}
/// Builder for [`TextRecognitionPredictor`].
///
/// Holds the recognition configuration plus the path to the character
/// dictionary, which must be set before calling `build`.
// NOTE(review): the `TaskPredictorBuilder` derive presumably generates extra
// builder methods from the `state` field — confirm against the macro.
#[derive(TaskPredictorBuilder)]
#[builder(config = TextRecognitionConfig)]
pub struct TextRecognitionPredictorBuilder {
// Shared builder state carrying the task config (and optional ORT config).
state: PredictorBuilderState<TextRecognitionConfig>,
// Location of the character dictionary file; `None` until `dict_path` is called.
dict_path: Option<PathBuf>,
}
impl TextRecognitionPredictorBuilder {
    /// Creates a builder with the default configuration: a score threshold of
    /// `0.0`, a maximum text length of `100`, and no dictionary path.
    pub fn new() -> Self {
        let defaults = TextRecognitionConfig {
            score_threshold: 0.0,
            max_text_length: 100,
        };
        Self {
            state: PredictorBuilderState::new(defaults),
            dict_path: None,
        }
    }

    /// Overrides the `score_threshold` value stored in the recognition config.
    pub fn score_threshold(mut self, threshold: f32) -> Self {
        let config = self.state.config_mut();
        config.score_threshold = threshold;
        self
    }

    /// Overrides the `max_text_length` value stored in the recognition config.
    pub fn max_text_length(mut self, max_length: usize) -> Self {
        let config = self.state.config_mut();
        config.max_text_length = max_length;
        self
    }

    /// Sets the path of the character dictionary file read by [`Self::build`].
    pub fn dict_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        let owned_path = path.as_ref().to_owned();
        self.dict_path = Some(owned_path);
        self
    }

    /// Consumes the builder and constructs a [`TextRecognitionPredictor`]
    /// backed by the model at `model_path`.
    ///
    /// The dictionary file is read in full and every line becomes one entry of
    /// the character dictionary handed to the adapter.
    ///
    /// # Errors
    ///
    /// Returns an error if no dictionary path was configured, if the
    /// dictionary file cannot be read, or if the adapter fails to build.
    pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<TextRecognitionPredictor> {
        let (config, ort_config) = self.state.into_parts();

        // The dictionary path is mandatory; fail before touching the filesystem.
        let dict_file = self
            .dict_path
            .ok_or_else(|| OCRError::missing_field("dict_path", "TextRecognitionPredictor"))?;

        let dict_contents = std::fs::read_to_string(&dict_file)?;
        let character_dict: Vec<String> = dict_contents.lines().map(String::from).collect();

        let base_builder = TextRecognitionAdapterBuilder::new()
            .with_config(config.clone())
            .character_dict(character_dict);

        // Apply the ONNX Runtime configuration only when one was supplied.
        let adapter_builder = match ort_config {
            Some(ort_cfg) => base_builder.with_ort_config(ort_cfg),
            None => base_builder,
        };

        let boxed_adapter = Box::new(adapter_builder.build(model_path.as_ref())?);
        let recognition_task = TextRecognitionTask::new(config.clone());
        let core = TaskPredictorCore::new(boxed_adapter, recognition_task, config);

        Ok(TextRecognitionPredictor { core })
    }
}
impl Default for TextRecognitionPredictorBuilder {
fn default() -> Self {
Self::new()
}
}