use crate::core::OCRError;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::processors::{FormulaPreprocessParams, FormulaPreprocessor};
use image::RgbImage;
use ndarray::{ArrayBase, Axis, Data, Ix2};
/// Preprocessing options for PP-FormulaNet input images.
///
/// All fields are forwarded verbatim into `FormulaPreprocessParams` when a
/// `PPFormulaNetModel` is constructed.
#[derive(Debug, Clone)]
pub struct PPFormulaNetPreprocessConfig {
    /// Target image size as (width, height).
    pub target_size: (u32, u32),
    /// Threshold forwarded to the preprocessor's crop step.
    pub crop_threshold: u8,
    /// Padded dimensions are aligned to a multiple of this value.
    pub padding_multiple: usize,
    /// Per-channel normalization mean.
    pub normalize_mean: [f32; 3],
    /// Per-channel normalization standard deviation.
    pub normalize_std: [f32; 3],
}

impl Default for PPFormulaNetPreprocessConfig {
    /// Stock PP-FormulaNet preprocessing defaults (384x384 input).
    fn default() -> Self {
        PPFormulaNetPreprocessConfig {
            target_size: (384, 384),
            crop_threshold: 200,
            padding_multiple: 16,
            normalize_mean: [0.7931, 0.7931, 0.7931],
            normalize_std: [0.1738, 0.1738, 0.1738],
        }
    }
}
/// Postprocessing options: the special-token ids stripped by
/// `PPFormulaNetModel::filter_tokens`.
#[derive(Debug, Clone)]
pub struct PPFormulaNetPostprocessConfig {
    /// Start-of-sequence token id (removed from decoded output).
    pub sos_token_id: i64,
    /// End-of-sequence token id (decoding stops before this id).
    pub eos_token_id: i64,
}

impl Default for PPFormulaNetPostprocessConfig {
    /// Default special-token ids: SOS = 0, EOS = 2.
    fn default() -> Self {
        PPFormulaNetPostprocessConfig {
            sos_token_id: 0,
            eos_token_id: 2,
        }
    }
}
/// Raw model output: the decoder's token-id matrix, one row per batch element.
///
/// Ids are returned unfiltered (postprocess passes them through unchanged);
/// use `PPFormulaNetModel::filter_tokens` to strip SOS/EOS special tokens.
#[derive(Debug, Clone)]
pub struct PPFormulaNetModelOutput {
    /// Shape (batch, sequence_length) of i64 token ids.
    pub token_ids: ndarray::Array2<i64>,
}
/// PP-FormulaNet formula-recognition model: image preprocessing plus ONNX
/// inference producing LaTeX token ids.
#[derive(Debug)]
pub struct PPFormulaNetModel {
    /// ONNX Runtime session wrapper used for inference.
    inference: OrtInfer,
    /// Batch image preprocessor built from the preprocess config in `new`.
    preprocessor: FormulaPreprocessor,
    /// Original config, retained but not read after construction
    /// (hence the leading underscore).
    _preprocess_config: PPFormulaNetPreprocessConfig,
}
impl PPFormulaNetModel {
    /// Creates a model from an ONNX session and a preprocessing configuration.
    ///
    /// The configuration fields are copied into `FormulaPreprocessParams` for
    /// the internal preprocessor; the original config is kept for reference.
    pub fn new(
        inference: OrtInfer,
        preprocess_config: PPFormulaNetPreprocessConfig,
    ) -> Result<Self, OCRError> {
        let params = FormulaPreprocessParams {
            target_size: preprocess_config.target_size,
            crop_threshold: preprocess_config.crop_threshold,
            padding_multiple: preprocess_config.padding_multiple,
            normalize_mean: preprocess_config.normalize_mean,
            normalize_std: preprocess_config.normalize_std,
        };
        let preprocessor = FormulaPreprocessor::new(params);
        Ok(Self {
            inference,
            preprocessor,
            _preprocess_config: preprocess_config,
        })
    }

    /// Preprocesses a batch of RGB images into a single NCHW float tensor.
    pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<ndarray::Array4<f32>, OCRError> {
        self.preprocessor.preprocess_batch(&images)
    }

    /// Runs the ONNX session on a preprocessed batch and returns the raw
    /// token-id matrix (one row per batch element).
    ///
    /// # Errors
    /// Returns `OCRError::Inference` when the session fails or the first
    /// output cannot be viewed as a 2-D i64 array, and
    /// `OCRError::InvalidInput` when the session returns no outputs.
    pub fn infer(
        &self,
        batch_tensor: &ndarray::Array4<f32>,
    ) -> Result<ndarray::Array2<i64>, OCRError> {
        let input_name = self.inference.input_name();
        let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
        let outputs = self
            .inference
            .infer(&inputs)
            .map_err(|e| OCRError::Inference {
                model_name: "PP-FormulaNet".to_string(),
                context: format!(
                    "failed to run inference on batch with shape {:?}",
                    batch_tensor.shape()
                ),
                source: Box::new(e),
            })?;
        // Only the first output is used; the model is expected to emit a
        // single (batch, seq_len) id tensor.
        let output = outputs
            .into_iter()
            .next()
            .ok_or_else(|| OCRError::InvalidInput {
                message: "PP-FormulaNet: no output returned from inference".to_string(),
            })?;
        output
            .1
            .try_into_array2_i64()
            .map_err(|e| OCRError::Inference {
                model_name: "PP-FormulaNet".to_string(),
                context: "failed to convert output to 2D i64 array".to_string(),
                source: Box::new(e),
            })
    }

    /// Wraps raw token ids into the output type. Special-token filtering is
    /// deliberately deferred to `filter_tokens`, so the config is unused here.
    pub fn postprocess(
        &self,
        token_ids: ndarray::Array2<i64>,
        _config: &PPFormulaNetPostprocessConfig,
    ) -> Result<PPFormulaNetModelOutput, OCRError> {
        Ok(PPFormulaNetModelOutput { token_ids })
    }

    /// Full pipeline: preprocess -> infer -> postprocess.
    pub fn forward(
        &self,
        images: Vec<RgbImage>,
        config: &PPFormulaNetPostprocessConfig,
    ) -> Result<PPFormulaNetModelOutput, OCRError> {
        let batch_tensor = self.preprocess(images)?;
        let token_ids = self.infer(&batch_tensor)?;
        let output = self.postprocess(token_ids, config)?;
        Ok(output)
    }

    /// Strips special tokens from each row of raw decoder output.
    ///
    /// Each row is read up to (and excluding) the first `eos_token_id`;
    /// `sos_token_id` occurrences and any id not representable as `u32`
    /// (negative or above `u32::MAX`) are dropped.
    pub fn filter_tokens<D>(
        token_ids: &ArrayBase<D, Ix2>,
        config: &PPFormulaNetPostprocessConfig,
    ) -> Vec<Vec<u32>>
    where
        D: Data<Elem = i64>,
    {
        let batch_size = token_ids.shape()[0];
        let mut filtered_tokens = Vec::with_capacity(batch_size);
        for batch_idx in 0..batch_size {
            let row = token_ids.index_axis(Axis(0), batch_idx);
            let tokens: Vec<u32> = row
                .iter()
                .copied()
                .take_while(|&id| id != config.eos_token_id)
                .filter(|&id| id != config.sos_token_id)
                // `u32::try_from` rejects both negatives and ids above
                // u32::MAX, unlike the lossy `as u32` cast which would
                // silently truncate oversized ids to wrong token values.
                .filter_map(|id| u32::try_from(id).ok())
                .collect();
            filtered_tokens.push(tokens);
        }
        filtered_tokens
    }
}
/// Builder for `PPFormulaNetModel`; every option may be left unset.
#[derive(Debug, Default)]
pub struct PPFormulaNetModelBuilder {
    /// Preprocessing overrides; `None` falls back to the config default at build time.
    preprocess_config: Option<PPFormulaNetPreprocessConfig>,
    /// Optional ONNX Runtime session configuration.
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}
impl PPFormulaNetModelBuilder {
    /// Creates an empty builder; unset options use defaults in `build`.
    pub fn new() -> Self {
        Self {
            preprocess_config: None,
            ort_config: None,
        }
    }

    /// Replaces the whole preprocessing configuration at once.
    pub fn preprocess_config(mut self, config: PPFormulaNetPreprocessConfig) -> Self {
        self.preprocess_config = Some(config);
        self
    }

    /// Sets the preprocessing target size as (width, height).
    pub fn target_size(mut self, width: u32, height: u32) -> Self {
        let mut config = self.preprocess_config.unwrap_or_default();
        config.target_size = (width, height);
        self.preprocess_config = Some(config);
        self
    }

    /// Sets the padding multiple used during preprocessing.
    pub fn padding_multiple(mut self, multiple: usize) -> Self {
        let mut config = self.preprocess_config.unwrap_or_default();
        config.padding_multiple = multiple;
        self.preprocess_config = Some(config);
        self
    }

    /// Sets the crop threshold used during preprocessing.
    /// (Added for parity with the other per-field setters.)
    pub fn crop_threshold(mut self, threshold: u8) -> Self {
        let mut config = self.preprocess_config.unwrap_or_default();
        config.crop_threshold = threshold;
        self.preprocess_config = Some(config);
        self
    }

    /// Sets the per-channel normalization mean and standard deviation.
    /// (Added for parity with the other per-field setters.)
    pub fn normalization(mut self, mean: [f32; 3], std: [f32; 3]) -> Self {
        let mut config = self.preprocess_config.unwrap_or_default();
        config.normalize_mean = mean;
        config.normalize_std = std;
        self.preprocess_config = Some(config);
        self
    }

    /// Supplies an explicit ONNX Runtime session configuration.
    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
        self.ort_config = Some(config);
        self
    }

    /// Builds the model, loading the ONNX session from `model_path`.
    ///
    /// When the target size is still the (384, 384) default and the loaded
    /// model reports a static input shape of rank >= 4, the target size is
    /// taken from the shape's trailing (H, W) dimensions instead.
    ///
    /// NOTE(review): an explicitly requested 384x384 is indistinguishable
    /// from the default here and will also be overridden by the model shape.
    ///
    /// # Errors
    /// Propagates `OCRError` from session creation or model construction.
    pub fn build(self, model_path: &std::path::Path) -> Result<PPFormulaNetModel, OCRError> {
        let inference = if self.ort_config.is_some() {
            use crate::core::config::ModelInferenceConfig;
            let common_config = ModelInferenceConfig {
                ort_session: self.ort_config,
                ..Default::default()
            };
            OrtInfer::from_config(&common_config, model_path, None)?
        } else {
            OrtInfer::new(model_path, None)?
        };
        let mut preprocess_config = self.preprocess_config.unwrap_or_default();
        if preprocess_config.target_size == (384, 384)
            && let Some(shape) = inference.primary_input_shape()
            && shape.len() >= 4
        {
            // NCHW layout assumed: the last two dims are height, width.
            let height = shape[shape.len() - 2];
            let width = shape[shape.len() - 1];
            // Dynamic dimensions are reported as non-positive; keep the
            // default in that case.
            if height > 0 && width > 0 {
                preprocess_config.target_size = (width as u32, height as u32);
            }
        }
        PPFormulaNetModel::new(inference, preprocess_config)
    }
}