use crate::core::OCRError;
use crate::core::config::InputShape;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::processors::NormalizeImage;
use image::{DynamicImage, RgbImage};
use ndarray::s;
use std::path::Path;
/// Raw outputs of a SLANet forward pass, prior to any decoding.
#[derive(Debug, Clone)]
pub struct SLANetModelOutput {
    /// Per-step structure token logits from the model's second output.
    /// Assumed (batch, steps, vocab) — TODO confirm axis order against the model.
    pub structure_logits: ndarray::Array3<f32>,
    /// Bounding-box regression values from the model's first output,
    /// aligned with the structure decoding steps.
    pub bbox_preds: ndarray::Array3<f32>,
    /// One record per input image:
    /// `[orig_h, orig_w, scale, pad_h, pad_w, target_size]` as produced by `preprocess`.
    pub shape_info: Vec<[f32; 6]>,
}
/// SLANet table-structure recognition model backed by an ONNX Runtime session.
#[derive(Debug)]
pub struct SLANetModel {
    // ONNX Runtime inference wrapper that executes the network.
    inference: OrtInfer,
    // Scales pixel values and applies per-channel normalization before inference.
    normalizer: NormalizeImage,
    // Expected input tensor shape; drives resizing/padding in `preprocess`.
    input_shape: InputShape,
}
impl SLANetModel {
    /// Assembles a model from an inference session, a pixel normalizer, and
    /// the input shape the network expects.
    pub fn new(inference: OrtInfer, normalizer: NormalizeImage, input_shape: InputShape) -> Self {
        Self {
            inference,
            normalizer,
            input_shape,
        }
    }

    /// Returns the input shape this model was built with.
    pub fn input_shape(&self) -> &InputShape {
        &self.input_shape
    }

    /// Resizes each image so its longest side matches the target size,
    /// normalizes it, and — when the model has a fixed spatial input —
    /// zero-pads it into a square tensor.
    ///
    /// Returns the batched NCHW tensor together with one
    /// `[orig_h, orig_w, scale, pad_h, pad_w, target_size]` record per image.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::InvalidInput`] for a zero-sized image or when the
    /// per-image tensors end up with inconsistent shapes.
    pub fn preprocess(
        &self,
        images: Vec<RgbImage>,
    ) -> Result<(ndarray::Array4<f32>, Vec<[f32; 6]>), OCRError> {
        let mut shape_info_list = Vec::with_capacity(images.len());
        let mut processed_tensors = Vec::with_capacity(images.len());
        let needs_padding = self.input_shape.has_fixed_spatial();
        // Fall back to 488x488 when the model does not declare spatial dims.
        let (target_h, target_w) = self.input_shape.spatial_size_or(488, 488);
        let target_size = target_h.max(target_w) as f32;
        tracing::debug!(
            "SLANet preprocess: needs_padding={}, target_size={}",
            needs_padding,
            target_size
        );
        for img in images {
            let (orig_h, orig_w) = (img.height() as f32, img.width() as f32);
            let longest_side = orig_h.max(orig_w);
            // Guard against degenerate images: a zero-sized input would make
            // `scale` infinite and the resized dimensions NaN.
            if longest_side <= 0.0 {
                return Err(OCRError::InvalidInput {
                    message: "SLANet preprocess received a zero-sized image".to_string(),
                });
            }
            let scale = target_size / longest_side;
            let resized_h = (orig_h * scale).round() as u32;
            let resized_w = (orig_w * scale).round() as u32;
            tracing::debug!(
                "SLANet resize: orig={}x{}, resized={}x{}, scale={:.4}",
                orig_w,
                orig_h,
                resized_w,
                resized_h,
                scale
            );
            let resized = image::imageops::resize(
                &img,
                resized_w,
                resized_h,
                image::imageops::FilterType::Triangle,
            );
            // `resized` is not used after this point, so it is moved into the
            // DynamicImage instead of being cloned (full-image copy avoided).
            let normalized = self
                .normalizer
                .normalize_to(DynamicImage::ImageRgb8(resized))?;
            let (tensor, pad_h, pad_w) = if needs_padding {
                // Zero-pad the normalized image into the top-left corner of a
                // square target_size x target_size tensor.
                let target_px = target_size as usize;
                let mut padded = ndarray::Array4::<f32>::zeros((1, 3, target_px, target_px));
                padded
                    .slice_mut(s![0, .., 0..(resized_h as usize), 0..(resized_w as usize)])
                    .assign(&normalized.slice(s![0, .., .., ..]));
                let pad_h = target_size - (resized_h as f32);
                let pad_w = target_size - (resized_w as f32);
                (padded, pad_h, pad_w)
            } else {
                (normalized, 0.0, 0.0)
            };
            shape_info_list.push([orig_h, orig_w, scale, pad_h, pad_w, target_size]);
            processed_tensors.push(tensor);
        }
        let batch_tensor = if processed_tensors.is_empty() {
            // Empty batch: propagate an empty tensor rather than erroring.
            ndarray::Array4::<f32>::zeros((0, 0, 0, 0))
        } else {
            let first_shape = processed_tensors[0].shape().to_vec();
            let (channels, height, width) = (first_shape[1], first_shape[2], first_shape[3]);
            // All per-image tensors must agree in C/H/W before stacking.
            if !processed_tensors.iter().all(|t| {
                t.shape()[1] == channels && t.shape()[2] == height && t.shape()[3] == width
            }) {
                return Err(OCRError::InvalidInput {
                    message: "SLANet preprocess produced tensors with inconsistent shapes"
                        .to_string(),
                });
            }
            let mut batch =
                ndarray::Array4::<f32>::zeros((processed_tensors.len(), channels, height, width));
            for (i, tensor) in processed_tensors.iter().enumerate() {
                batch
                    .slice_mut(s![i, .., .., ..])
                    .assign(&tensor.slice(s![0, .., .., ..]));
            }
            batch
        };
        Ok((batch_tensor, shape_info_list))
    }

    /// Runs the ONNX session on a preprocessed batch.
    ///
    /// Returns `(bbox_preds, structure_logits)` — the model emits the bbox
    /// tensor first and the structure logits second.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::Inference`] when the session fails or an output
    /// cannot be converted to a 3D `f32` array, and
    /// [`OCRError::InvalidInput`] when fewer than two outputs are produced.
    pub fn infer(
        &self,
        batch_tensor: &ndarray::Array4<f32>,
    ) -> Result<(ndarray::Array3<f32>, ndarray::Array3<f32>), OCRError> {
        let input_name = self.inference.input_name();
        let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
        let outputs = self
            .inference
            .infer(&inputs)
            .map_err(|e| OCRError::Inference {
                model_name: "SLANet".to_string(),
                context: format!(
                    "failed to run inference on batch with shape {:?}",
                    batch_tensor.shape()
                ),
                source: Box::new(e),
            })?;
        if outputs.len() < 2 {
            return Err(OCRError::InvalidInput {
                message: format!("SLANet: expected at least 2 outputs, got {}", outputs.len()),
            });
        }
        // Consume the output tensors by value instead of cloning them: each
        // can be large and neither is needed after conversion.
        let mut outputs = outputs.into_iter();
        let (_, bbox_raw) = outputs.next().expect("length checked above");
        let (_, structure_raw) = outputs.next().expect("length checked above");
        let bbox_preds = bbox_raw
            .try_into_array3_f32()
            .map_err(|e| OCRError::Inference {
                model_name: "SLANet".to_string(),
                context: "failed to convert first output (bbox_preds) to 3D array".to_string(),
                source: Box::new(e),
            })?;
        let structure_logits = structure_raw
            .try_into_array3_f32()
            .map_err(|e| OCRError::Inference {
                model_name: "SLANet".to_string(),
                context: "failed to convert second output (structure_logits) to 3D array"
                    .to_string(),
                source: Box::new(e),
            })?;
        Ok((bbox_preds, structure_logits))
    }

    /// Convenience wrapper: `preprocess` followed by `infer`, packed into a
    /// [`SLANetModelOutput`].
    pub fn forward(&self, images: Vec<RgbImage>) -> Result<SLANetModelOutput, OCRError> {
        let (batch_tensor, shape_info) = self.preprocess(images)?;
        let (bbox_preds, structure_logits) = self.infer(&batch_tensor)?;
        Ok(SLANetModelOutput {
            structure_logits,
            bbox_preds,
            shape_info,
        })
    }
}
/// Builder for [`SLANetModel`]: optional input-shape override and optional
/// ONNX Runtime session configuration.
#[derive(Debug, Default)]
pub struct SLANetModelBuilder {
    // Explicit input shape; when `None`, `build` derives it from the ONNX model.
    input_shape: Option<InputShape>,
    // Optional ONNX Runtime session configuration.
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}
impl SLANetModelBuilder {
    /// Creates a builder with no overrides: the input shape is derived from
    /// the ONNX model and default session options are used.
    pub fn new() -> Self {
        Self {
            input_shape: None,
            ort_config: None,
        }
    }

    /// Sets a dynamic-batch input shape of 3 x `size.0` x `size.1`.
    pub fn input_size(mut self, size: (u32, u32)) -> Self {
        self.input_shape = Some(InputShape::dynamic_batch(3, size.0 as i64, size.1 as i64));
        self
    }

    /// Sets an explicit input shape, overriding ONNX-derived dimensions.
    pub fn input_shape(mut self, shape: InputShape) -> Self {
        self.input_shape = Some(shape);
        self
    }

    /// Supplies custom ONNX Runtime session options for `build`.
    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
        self.ort_config = Some(config);
        self
    }

    /// Loads the ONNX model at `model_path` and assembles a [`SLANetModel`].
    ///
    /// Input-shape resolution order: explicit builder override, then the
    /// shape declared by the ONNX model, then a 3x512x512 dynamic-batch
    /// fallback.
    ///
    /// # Errors
    ///
    /// Propagates [`OCRError`] from session creation or normalizer setup.
    pub fn build(self, model_path: &Path) -> Result<SLANetModel, OCRError> {
        // Match on the option directly instead of `is_some()` + re-read.
        let inference = match self.ort_config {
            Some(ort_session) => {
                use crate::core::config::ModelInferenceConfig;
                let common_config = ModelInferenceConfig {
                    ort_session: Some(ort_session),
                    ..Default::default()
                };
                OrtInfer::from_config(&common_config, model_path, None)?
            }
            None => OrtInfer::new(model_path, None)?,
        };
        // Override > ONNX-declared shape > fallback. The 3x512x512 fallback
        // was previously duplicated across two branches.
        let input_shape = self
            .input_shape
            .or_else(|| {
                inference
                    .primary_input_shape()
                    .and_then(|dims| InputShape::from_onnx_dims(&dims))
            })
            .unwrap_or_else(|| InputShape::dynamic_batch(3, 512, 512));
        tracing::debug!(
            "SLANet input shape: {} (fixed_spatial: {})",
            input_shape,
            input_shape.has_fixed_spatial()
        );
        // Scale by 1/255, then normalize with per-channel mean/std,
        // interpreting channels in BGR order.
        let normalizer = NormalizeImage::with_color_order(
            Some(1.0 / 255.0),
            Some(vec![0.485, 0.456, 0.406]),
            Some(vec![0.229, 0.224, 0.225]),
            None,
            Some(crate::processors::types::ColorOrder::BGR),
        )?;
        Ok(SLANetModel::new(inference, normalizer, input_shape))
    }
}