use super::builder::PredictorBuilderState;
use crate::TaskPredictorBuilder;
use crate::core::OcrResult;
use crate::core::errors::OCRError;
use crate::core::traits::OrtConfigurable;
use crate::core::traits::adapter::AdapterBuilder;
use crate::core::traits::task::ImageTaskInput;
use crate::domain::adapters::{PPFormulaNetAdapterBuilder, UniMERNetAdapterBuilder};
use crate::domain::tasks::formula_recognition::{FormulaRecognitionConfig, FormulaRecognitionTask};
use crate::predictors::TaskPredictorCore;
use image::RgbImage;
use std::path::{Path, PathBuf};
/// Formula recognition model families supported by
/// `FormulaRecognitionPredictorBuilder`.
///
/// The kind decides which adapter builder is used when the predictor is built.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FormulaModelKind {
    /// UniMERNet models; built through `UniMERNetAdapterBuilder`.
    UniMERNet,
    /// PP-FormulaNet models; built through `PPFormulaNetAdapterBuilder`.
    PPFormulaNet,
}

impl FormulaModelKind {
    /// Infers the model kind from a model name.
    ///
    /// Resolution order:
    /// 1. Exact, case-sensitive match against the known names (`"UniMERNet"`,
    ///    `"PP-FormulaNet-S"`, `"PP-FormulaNet-L"`, `"PP-FormulaNet_plus-S"`,
    ///    `"PP-FormulaNet_plus-M"`, `"PP-FormulaNet_plus-L"`). This path does
    ///    not allocate.
    /// 2. Case-insensitive substring match: a name containing `unimernet` is
    ///    [`FormulaModelKind::UniMERNet`]; otherwise a name containing
    ///    `pp-formulanet`, `pp_formulanet`, or `ppformulanet` is
    ///    [`FormulaModelKind::PPFormulaNet`].
    /// 3. Anything else (including the empty string) falls back to
    ///    [`FormulaModelKind::UniMERNet`].
    pub fn from_model_name(name: &str) -> Self {
        match name {
            "UniMERNet" => Self::UniMERNet,
            "PP-FormulaNet-S"
            | "PP-FormulaNet-L"
            | "PP-FormulaNet_plus-S"
            | "PP-FormulaNet_plus-M"
            | "PP-FormulaNet_plus-L" => Self::PPFormulaNet,
            _ => {
                // Lowercase spellings accepted for PP-FormulaNet. The underscore
                // form is included so snake_case names (e.g. derived from file
                // names) do not silently fall through to the UniMERNet default.
                const PP_FORMULANET_MARKERS: [&str; 3] =
                    ["pp-formulanet", "pp_formulanet", "ppformulanet"];

                let name_lower = name.to_lowercase();
                if name_lower.contains("unimernet") {
                    // UniMERNet takes precedence when both markers are present.
                    Self::UniMERNet
                } else if PP_FORMULANET_MARKERS
                    .iter()
                    .any(|marker| name_lower.contains(*marker))
                {
                    Self::PPFormulaNet
                } else {
                    Self::UniMERNet
                }
            }
        }
    }
}
/// Result returned by `FormulaRecognitionPredictor::predict`.
#[derive(Debug, Clone)]
pub struct FormulaRecognitionResult {
/// Formula strings, taken from the `formulas` field of the task output.
/// NOTE(review): presumably one entry per input image — confirm against the task.
pub formulas: Vec<String>,
/// Optional scores, taken from the `scores` field of the task output.
/// NOTE(review): presumably parallel to `formulas`, with `None` when no score
/// is available — confirm against the task output definition.
pub scores: Vec<Option<f32>>,
}
/// Predictor for the formula recognition task.
///
/// Wraps a `TaskPredictorCore` specialised for `FormulaRecognitionTask`;
/// instances are created through `FormulaRecognitionPredictor::builder()`.
pub struct FormulaRecognitionPredictor {
/// Task-generic core that `predict` delegates to.
core: TaskPredictorCore<FormulaRecognitionTask>,
}
impl FormulaRecognitionPredictor {
    /// Starts configuring a new predictor.
    ///
    /// Equivalent to calling `FormulaRecognitionPredictorBuilder::new()`.
    pub fn builder() -> FormulaRecognitionPredictorBuilder {
        FormulaRecognitionPredictorBuilder::new()
    }

    /// Runs formula recognition on a batch of images.
    ///
    /// The images are wrapped in an `ImageTaskInput` and handed to the
    /// underlying predictor core; the `formulas` and `scores` of its output
    /// are repackaged into a [`FormulaRecognitionResult`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the predictor core.
    pub fn predict(&self, images: Vec<RgbImage>) -> OcrResult<FormulaRecognitionResult> {
        let task_output = self.core.predict(ImageTaskInput::new(images))?;

        let formulas = task_output.formulas;
        let scores = task_output.scores;
        Ok(FormulaRecognitionResult { formulas, scores })
    }
}
/// Builder for `FormulaRecognitionPredictor`.
///
/// NOTE(review): the `TaskPredictorBuilder` derive (configured with
/// `config = FormulaRecognitionConfig`) presumably generates the shared
/// configuration methods operating on `state` — confirm in the macro definition.
#[derive(TaskPredictorBuilder)]
#[builder(config = FormulaRecognitionConfig)]
pub struct FormulaRecognitionPredictorBuilder {
// Task configuration plus optional ORT configuration; split apart in `build`
// via `into_parts`.
state: PredictorBuilderState<FormulaRecognitionConfig>,
// Forwarded to the adapter builder; also used to infer the model kind when
// `model_kind` is not set.
model_name: String,
// Tokenizer location; `build` returns a missing-field error when this is `None`.
tokenizer_path: Option<PathBuf>,
// Optional `(width, height)`; forwarded to the adapter builder only when set.
target_size: Option<(u32, u32)>,
// Explicit model kind; when `None`, `build` derives it from `model_name`.
model_kind: Option<FormulaModelKind>,
}
impl FormulaRecognitionPredictorBuilder {
    /// Creates a builder with the default configuration.
    ///
    /// Defaults: `score_threshold = 0.0`, `max_length = 1536`, model name
    /// `"FormulaRecognition"`, and no tokenizer path, target size, or explicit
    /// model kind.
    pub fn new() -> Self {
        Self {
            state: PredictorBuilderState::new(FormulaRecognitionConfig {
                score_threshold: 0.0,
                max_length: 1536,
            }),
            model_name: "FormulaRecognition".to_string(),
            tokenizer_path: None,
            target_size: None,
            model_kind: None,
        }
    }

    /// Sets `score_threshold` in the task configuration.
    pub fn score_threshold(mut self, threshold: f32) -> Self {
        self.state.config_mut().score_threshold = threshold;
        self
    }

    /// Sets `max_length` in the task configuration.
    pub fn max_length(mut self, max: usize) -> Self {
        self.state.config_mut().max_length = max;
        self
    }

    /// Sets the model name.
    ///
    /// The name is forwarded to the adapter builder and, when no explicit
    /// [`FormulaModelKind`] was supplied via [`Self::model_kind`], is also used
    /// to infer the kind through `FormulaModelKind::from_model_name`.
    pub fn model_name(mut self, name: &str) -> Self {
        self.model_name = name.to_string();
        self
    }

    /// Sets the tokenizer path. This is required: [`Self::build`] fails
    /// without it.
    pub fn tokenizer_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.tokenizer_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the target size, stored as `(width, height)` and forwarded to the
    /// adapter builder. When unset, the adapter builder's own default applies.
    pub fn target_size(mut self, width: u32, height: u32) -> Self {
        self.target_size = Some((width, height));
        self
    }

    /// Explicitly selects the model kind, bypassing inference from the model
    /// name.
    pub fn model_kind(mut self, kind: FormulaModelKind) -> Self {
        self.model_kind = Some(kind);
        self
    }

    /// Builds the predictor from the model at `model_path`.
    ///
    /// The adapter is created with `UniMERNetAdapterBuilder` or
    /// `PPFormulaNetAdapterBuilder` depending on the model kind, which is
    /// either the one set via [`Self::model_kind`] or the one inferred from
    /// the model name.
    ///
    /// # Errors
    ///
    /// Returns `OCRError::missing_field` if no tokenizer path was configured,
    /// and propagates any error from building the adapter.
    pub fn build<P: AsRef<Path>>(self, model_path: P) -> OcrResult<FormulaRecognitionPredictor> {
        let Self {
            state,
            model_name,
            tokenizer_path,
            target_size,
            model_kind,
        } = self;
        let (config, ort_config) = state.into_parts();

        let tokenizer_path = tokenizer_path.ok_or_else(|| {
            OCRError::missing_field("tokenizer_path", "FormulaRecognitionPredictor")
        })?;
        let model_kind =
            model_kind.unwrap_or_else(|| FormulaModelKind::from_model_name(&model_name));

        // Exactly one arm runs and neither `tokenizer_path` nor `ort_config` is
        // used after the match, so both are moved into the adapter builder
        // rather than cloned. `config` is still needed below, hence its clone.
        let adapter = match model_kind {
            FormulaModelKind::UniMERNet => {
                let mut builder = UniMERNetAdapterBuilder::new()
                    .with_config(config.clone())
                    .model_name(&model_name)
                    .tokenizer_path(tokenizer_path);
                if let Some((width, height)) = target_size {
                    builder = builder.target_size(width, height);
                }
                if let Some(ort_cfg) = ort_config {
                    builder = builder.with_ort_config(ort_cfg);
                }
                Box::new(builder.build(model_path.as_ref())?)
            }
            FormulaModelKind::PPFormulaNet => {
                let mut builder = PPFormulaNetAdapterBuilder::new()
                    .with_config(config.clone())
                    .model_name(&model_name)
                    .tokenizer_path(tokenizer_path);
                if let Some((width, height)) = target_size {
                    builder = builder.target_size(width, height);
                }
                if let Some(ort_cfg) = ort_config {
                    builder = builder.with_ort_config(ort_cfg);
                }
                Box::new(builder.build(model_path.as_ref())?)
            }
        };

        Ok(FormulaRecognitionPredictor {
            core: TaskPredictorCore::new(
                adapter,
                FormulaRecognitionTask::new(config.clone()),
                config,
            ),
        })
    }
}
impl Default for FormulaRecognitionPredictorBuilder {
fn default() -> Self {
Self::new()
}
}