use crate::core::OCRError;
use image::imageops::{FilterType, overlay, resize};
use image::{DynamicImage, RgbImage};
use ndarray::{Array2, Array3, Array4};
use regex::Regex;
use std::sync::LazyLock;
/// Matches a `\text{...}` group whose brace-free content contains at least one
/// character in the range U+4E00..=U+9FFF.
///
/// Capture group 1 is the content between the braces.
static CHINESE_TEXT_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
    match Regex::new(r"\\text\s*\{([^{}]*[\u{4e00}-\u{9fff}]+[^{}]*)\}") {
        Ok(pattern) => pattern,
        // The pattern is a compile-time constant, so failing here is a programming error.
        Err(e) => panic!("Failed to compile Chinese text regex pattern: {e}"),
    }
});
/// Matches a `\operatorname`, `\mathrm`, `\text` or `\mathbf` command together
/// with its argument.
///
/// After the command name the pattern allows one optional whitespace character,
/// an optional `*`, and any amount of whitespace before the opening brace. The
/// argument is matched lazily, so it ends at the first `}`.
static TEXT_COMMAND_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
    match Regex::new(r"(\\(operatorname|mathrm|text|mathbf)\s?\*?\s*\{.*?\})") {
        Ok(pattern) => pattern,
        // The pattern is a compile-time constant, so failing here is a programming error.
        Err(e) => panic!("Failed to compile text command regex pattern: {e}"),
    }
});
/// Matches an ASCII letter, one or more whitespace characters, and then any
/// character that is not an ASCII letter.
///
/// Capture group 1 is the letter and capture group 2 is the following character.
static LETTER_TO_NONLETTER_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
    match Regex::new(r"([a-zA-Z])\s+([^a-zA-Z])") {
        Ok(pattern) => pattern,
        // The pattern is a compile-time constant, so failing here is a programming error.
        Err(e) => panic!("Failed to compile letter to nonletter regex pattern: {e}"),
    }
});
/// Tunable settings for formula-image preprocessing.
#[derive(Debug, Clone, Copy)]
pub struct FormulaPreprocessParams {
    /// Canvas size as `(width, height)` in pixels.
    pub target_size: (u32, u32),
    /// Contrast-stretched luma values strictly below this threshold count as
    /// content when the margin is cropped.
    pub crop_threshold: u8,
    /// The output tensor's height and width are rounded up to a multiple of
    /// this value.
    pub padding_multiple: usize,
    /// Per-channel mean subtracted after scaling pixel values by `1 / 255`.
    /// Index 0 is applied to blue, index 1 to green and index 2 to red.
    pub normalize_mean: [f32; 3],
    /// Per-channel divisor applied after the mean is subtracted, in the same
    /// channel order as `normalize_mean`. Values below `f32::EPSILON` are
    /// replaced by `f32::EPSILON`.
    pub normalize_std: [f32; 3],
}
/// Turns RGB formula images into a batched tensor according to a fixed set of
/// [`FormulaPreprocessParams`].
#[derive(Debug, Clone)]
pub struct FormulaPreprocessor {
    /// Settings applied to every image processed by this instance.
    params: FormulaPreprocessParams,
}
impl FormulaPreprocessor {
    /// Creates a preprocessor that applies `params` to every image it handles.
    pub fn new(params: FormulaPreprocessParams) -> Self {
        Self { params }
    }

    /// Converts a batch of RGB images into one `(batch, 1, height, width)` tensor.
    ///
    /// Each image is cropped to its content, resized and centred on a
    /// `target_size` canvas, normalized, reduced to grayscale and finally
    /// copied into a tensor whose spatial size is rounded up to
    /// `padding_multiple`.
    ///
    /// # Errors
    ///
    /// No step of the current implementation fails, so this always returns `Ok`.
    pub fn preprocess_batch(
        &self,
        images: &[RgbImage],
    ) -> Result<ndarray::Array4<f32>, OCRError> {
        let normalized: Vec<Array3<f32>> = images
            .iter()
            .map(|image| {
                let cropped = self.crop_margin(image);
                let resized = self.resize_and_pad(&cropped);
                self.normalize_and_to_grayscale(&resized)
            })
            .collect();
        self.format_to_tensor(normalized)
    }

    /// Crops `img` to the bounding box of its content.
    ///
    /// Luma values are contrast-stretched to `0..=255`; pixels whose stretched
    /// value is strictly below `crop_threshold` count as content. The image is
    /// returned unchanged when it is empty, has uniform luma, or contains no
    /// content pixel.
    fn crop_margin(&self, img: &RgbImage) -> RgbImage {
        let gray = DynamicImage::ImageRgb8(img.clone()).to_luma8();

        let (min_val, max_val) = gray
            .pixels()
            .fold((u8::MAX, u8::MIN), |(lo, hi), pixel| {
                (lo.min(pixel[0]), hi.max(pixel[0]))
            });
        // Uniform images have nothing to crop; an empty image leaves the fold
        // seeds untouched (`max < min`) and is handled by the same guard.
        if max_val <= min_val {
            return img.clone();
        }

        let offset = f32::from(min_val);
        let range = f32::from(max_val) - f32::from(min_val);
        let threshold = self.params.crop_threshold;

        // Inclusive bounding box `(min_x, min_y, max_x, max_y)` of content pixels.
        // Using `Option` instead of sentinel values keeps one-pixel-wide or
        // one-pixel-tall content distinguishable from "nothing found".
        let mut bounds: Option<(u32, u32, u32, u32)> = None;
        for (x, y, pixel) in gray.enumerate_pixels() {
            let stretched = ((f32::from(pixel[0]) - offset) / range * 255.0) as u8;
            if stretched < threshold {
                bounds = Some(match bounds {
                    None => (x, y, x, y),
                    Some((x0, y0, x1, y1)) => (x0.min(x), y0.min(y), x1.max(x), y1.max(y)),
                });
            }
        }

        let Some((min_x, min_y, max_x, max_y)) = bounds else {
            return img.clone();
        };
        image::imageops::crop_imm(img, min_x, min_y, max_x - min_x + 1, max_y - min_y + 1)
            .to_image()
    }

    /// Scales `img` so its longer side equals the shorter side of
    /// `target_size`, then centres it on a black `target_size` canvas.
    ///
    /// An image with a zero dimension yields an empty (all-zero) canvas.
    fn resize_and_pad(&self, img: &RgbImage) -> RgbImage {
        let (target_width, target_height) = self.params.target_size;
        let (img_width, img_height) = img.dimensions();
        if img_width == 0 || img_height == 0 {
            return RgbImage::new(target_width, target_height);
        }

        let min_size = target_width.min(target_height);
        let scale = min_size as f32 / img_width.max(img_height) as f32;
        // Keep at least one pixel per side so extreme aspect ratios do not
        // truncate to an empty image, while never exceeding the canvas.
        let fit = |source: u32, limit: u32| ((source as f32 * scale) as u32).max(1).min(limit);
        let final_width = fit(img_width, target_width);
        let final_height = fit(img_height, target_height);

        let resized = resize(img, final_width, final_height, FilterType::Triangle);

        // `fit` caps each side at the canvas size, so these cannot underflow.
        let pad_left = (target_width - final_width) / 2;
        let pad_top = (target_height - final_height) / 2;

        let mut padded = RgbImage::from_pixel(target_width, target_height, image::Rgb([0, 0, 0]));
        overlay(&mut padded, &resized, i64::from(pad_left), i64::from(pad_top));
        padded
    }

    /// Normalizes every pixel and converts it to grayscale.
    ///
    /// Channels are scaled by `1 / 255`, shifted by `normalize_mean` and
    /// divided by `normalize_std` (floored at `f32::EPSILON`); index 0 of both
    /// arrays is applied to blue, index 1 to green and index 2 to red. The
    /// gray value is `0.114 * b + 0.587 * g + 0.299 * r` and is written to all
    /// three channels of the returned `(height, width, 3)` array.
    fn normalize_and_to_grayscale(&self, img: &RgbImage) -> Array3<f32> {
        const SCALE: f32 = 1.0 / 255.0;
        const BLUE_WEIGHT: f32 = 0.114;
        const GREEN_WEIGHT: f32 = 0.587;
        const RED_WEIGHT: f32 = 0.299;

        let (width, height) = img.dimensions();
        let mean = self.params.normalize_mean;
        let divisor = self.params.normalize_std.map(|s| s.max(f32::EPSILON));

        let mut result = Array3::<f32>::zeros((height as usize, width as usize, 3));
        for (x, y, pixel) in img.enumerate_pixels() {
            let b = (f32::from(pixel[2]) * SCALE - mean[0]) / divisor[0];
            let g = (f32::from(pixel[1]) * SCALE - mean[1]) / divisor[1];
            let r = (f32::from(pixel[0]) * SCALE - mean[2]) / divisor[2];
            let gray = BLUE_WEIGHT * b + GREEN_WEIGHT * g + RED_WEIGHT * r;

            let (row, col) = (y as usize, x as usize);
            for channel in 0..3 {
                result[[row, col, channel]] = gray;
            }
        }
        result
    }

    /// Packs per-image arrays into a `(batch, 1, padded_height, padded_width)` tensor.
    ///
    /// The padded height and width are the `target_size` dimensions rounded up
    /// to a multiple of `padding_multiple` (a value of 0 is treated as 1, i.e.
    /// no extra padding). Channel 0 of each image fills the top-left
    /// `target_height x target_width` region; the remaining cells hold `1.0`.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if any array in `images` is smaller than
    /// `target_height x target_width x 1`.
    fn format_to_tensor(&self, images: Vec<Array3<f32>>) -> Result<ndarray::Array4<f32>, OCRError> {
        const PADDING_VALUE: f32 = 1.0;

        let (target_width, target_height) = self.params.target_size;
        let (height, width) = (target_height as usize, target_width as usize);

        // Integer rounding avoids the NaN that a zero multiple produced in
        // floating point, which collapsed the tensor to size 0.
        let multiple = self.params.padding_multiple.max(1);
        let padded_height = height.next_multiple_of(multiple);
        let padded_width = width.next_multiple_of(multiple);

        let mut tensor = Array4::<f32>::from_elem(
            (images.len(), 1, padded_height, padded_width),
            PADDING_VALUE,
        );
        for (batch_idx, img) in images.iter().enumerate() {
            for y in 0..height {
                for x in 0..width {
                    tensor[[batch_idx, 0, y, x]] = img[[y, x, 0]];
                }
            }
        }
        Ok(tensor)
    }
}
pub fn normalize_latex(latex: &str) -> String {
let mut result = latex.to_string();
result = CHINESE_TEXT_PATTERN.replace_all(&result, "$1").to_string();
result = result.replace('"', "");
let mut names = Vec::new();
for mat in TEXT_COMMAND_PATTERN.find_iter(&result) {
let text = mat.as_str();
let cleaned = text.replace(" ", "");
names.push(cleaned);
}
if !names.is_empty() {
let mut names_iter = names.into_iter();
result = TEXT_COMMAND_PATTERN
.replace_all(&result, |_: ®ex::Captures| {
names_iter.next().unwrap_or_default()
})
.to_string();
}
let mut prev_result = String::new();
let max_iterations = 10;
let mut iterations = 0;
while prev_result != result && iterations < max_iterations {
prev_result = result.clone();
let mut temp = String::new();
let chars: Vec<char> = result.chars().collect();
let mut i = 0;
while i < chars.len() {
if i + 2 < chars.len()
&& chars[i] == '\\'
&& chars[i + 1] == '\\'
&& chars[i + 2] == ' '
{
temp.push(chars[i]);
temp.push(chars[i + 1]);
temp.push(chars[i + 2]);
i += 3;
} else if i + 1 < chars.len() && chars[i + 1].is_whitespace() {
let is_noletter_current = !chars[i].is_ascii_alphabetic();
let mut j = i + 1;
while j < chars.len() && chars[j].is_whitespace() {
j += 1;
}
if j < chars.len() {
let is_noletter_next = !chars[j].is_ascii_alphabetic();
if is_noletter_current && is_noletter_next {
temp.push(chars[i]);
i = j;
} else if is_noletter_current && chars[j].is_ascii_alphabetic() {
temp.push(chars[i]);
i = j;
} else {
temp.push(chars[i]);
i += 1;
}
} else {
temp.push(chars[i]);
i += 1;
}
} else {
temp.push(chars[i]);
i += 1;
}
}
result = temp;
result = LETTER_TO_NONLETTER_PATTERN
.replace_all(&result, "$1$2")
.to_string();
iterations += 1;
}
result.trim().to_string()
}