use std::path::PathBuf;
use std::time::Instant;
use image::RgbImage;
use crate::layout::error::LayoutError;
use crate::layout::model_manager::LayoutModelManager;
use crate::layout::models::LayoutModel;
use crate::layout::models::rtdetr::RtDetrModel;
use crate::layout::models::yolo::{YoloModel, YoloVariant};
use crate::layout::postprocessing::heuristics;
use crate::layout::types::DetectionResult;
/// Selects which layout model [`LayoutEngine::from_config`] loads.
#[derive(Debug, Clone)]
pub enum ModelBackend {
    /// Built-in YOLO DocLayNet backend.
    ///
    /// No automatic download exists for this model, so `LayoutEngine::from_config` rejects it
    /// with `LayoutError::ModelDownload`; point [`ModelBackend::Custom`] at a local file instead.
    YoloDocLayNet,
    /// RT-DETR model whose file is resolved by `LayoutModelManager::ensure_rtdetr_model`.
    RtDetr,
    /// A model loaded from a local file.
    Custom {
        /// Location of the model file on disk.
        path: PathBuf,
        /// Architecture of the model stored at `path`; decides how it is loaded.
        variant: CustomModelVariant,
    },
}
/// Architecture of a user-supplied layout model file (used by `ModelBackend::Custom`).
///
/// The variant picks the loader used by `LayoutEngine::from_config` and, for the YOLO
/// families, the input resolution the model is loaded with. It is plain data, so it is
/// `Copy` and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CustomModelVariant {
    /// RT-DETR model, loaded with `RtDetrModel::from_file`.
    RtDetr,
    /// YOLO model loaded as `YoloVariant::DocLayNet` with a fixed 640x640 input.
    YoloDocLayNet,
    /// DocLayout-YOLO model loaded as `YoloVariant::DocStructBench` with a fixed 1024x1024 input.
    YoloDocStructBench,
    /// YOLOX model loaded as `YoloVariant::Yolox` with a caller-chosen input size.
    Yolox {
        /// Input width passed to the YOLOX loader.
        input_width: u32,
        /// Input height passed to the YOLOX loader.
        input_height: u32,
    },
}
/// Settings consumed by [`LayoutEngine::from_config`] and by the engine's detect methods.
#[derive(Debug, Clone)]
pub struct LayoutEngineConfig {
    /// Which model to load.
    pub backend: ModelBackend,
    /// Score threshold override.
    ///
    /// For single images, `Some` routes through the model's `detect_with_threshold` and `None`
    /// through its plain `detect`. Batch detection forwards the value to the model unchanged.
    pub confidence_threshold: Option<f32>,
    /// When true, `heuristics::apply_heuristics` post-processes each page's detections.
    pub apply_heuristics: bool,
    /// Directory handed to `LayoutModelManager::new` when the RT-DETR backend is selected.
    pub cache_dir: Option<PathBuf>,
    /// Acceleration settings forwarded to whichever model loader is used.
    pub acceleration: Option<crate::core::config::acceleration::AccelerationConfig>,
}
impl Default for LayoutEngineConfig {
fn default() -> Self {
Self {
backend: ModelBackend::RtDetr,
confidence_threshold: None,
apply_heuristics: true,
cache_dir: None,
acceleration: None,
}
}
}
/// Timing breakdown, in milliseconds, for one layout detection call.
///
/// The struct holds only `f64`s, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct DetectTimings {
    /// Preprocessing time as reported by `crate::layout::inference_timings::take()`.
    pub preprocess_ms: f64,
    /// ONNX run time as reported by `crate::layout::inference_timings::take()`.
    pub onnx_ms: f64,
    /// Wall-clock time measured around the model's detect call.
    pub model_total_ms: f64,
    /// Wall-clock time spent in engine-side post-processing (heuristics).
    pub postprocess_ms: f64,
}
/// Runs a loaded layout model over page images and optionally post-processes the output.
///
/// Construct one with [`LayoutEngine::from_config`].
pub struct LayoutEngine {
    /// The loaded detection model; its concrete type is chosen from `config.backend`.
    model: Box<dyn LayoutModel>,
    /// The configuration the engine was built from; read on every detect call.
    config: LayoutEngineConfig,
}
impl LayoutEngine {
pub fn from_config(config: LayoutEngineConfig) -> Result<Self, LayoutError> {
crate::ort_discovery::ensure_ort_available();
let model: Box<dyn LayoutModel> = match &config.backend {
ModelBackend::YoloDocLayNet => {
return Err(LayoutError::ModelDownload(
"YOLO DocLayNet model is not available for automatic download. \
Use ModelBackend::Custom with a local YOLO ONNX file instead."
.into(),
));
}
ModelBackend::RtDetr => {
let manager = LayoutModelManager::new(config.cache_dir.clone());
let model_path = manager.ensure_rtdetr_model()?;
let path_str = model_path.to_string_lossy();
Box::new(RtDetrModel::from_file(&path_str, config.acceleration.as_ref())?)
}
ModelBackend::Custom { path, variant } => {
let path_str = path.to_string_lossy();
let accel = config.acceleration.as_ref();
match variant {
CustomModelVariant::RtDetr => Box::new(RtDetrModel::from_file(&path_str, accel)?),
CustomModelVariant::YoloDocLayNet => Box::new(YoloModel::from_file(
&path_str,
YoloVariant::DocLayNet,
640,
640,
"Custom-YOLO-DocLayNet",
accel,
)?),
CustomModelVariant::YoloDocStructBench => Box::new(YoloModel::from_file(
&path_str,
YoloVariant::DocStructBench,
1024,
1024,
"Custom-DocLayout-YOLO",
accel,
)?),
CustomModelVariant::Yolox {
input_width,
input_height,
} => Box::new(YoloModel::from_file(
&path_str,
YoloVariant::Yolox,
*input_width,
*input_height,
"Custom-YOLOX",
accel,
)?),
}
}
};
Ok(Self { model, config })
}
pub fn detect(&mut self, img: &RgbImage) -> Result<DetectionResult, LayoutError> {
let (result, _timings) = self.detect_timed(img)?;
for detection in &result.detections {
tracing::trace!(class = ?detection.class, confidence = detection.confidence, "Layout detection result");
}
Ok(result)
}
pub fn detect_timed(&mut self, img: &RgbImage) -> Result<(DetectionResult, DetectTimings), LayoutError> {
let model_start = Instant::now();
let mut detections = if let Some(threshold) = self.config.confidence_threshold {
self.model.detect_with_threshold(img, threshold)?
} else {
self.model.detect(img)?
};
let model_total_ms = model_start.elapsed().as_secs_f64() * 1000.0;
let (preprocess_ms, onnx_ms) = crate::layout::inference_timings::take();
let page_width = img.width();
let page_height = img.height();
let postprocess_start = Instant::now();
if self.config.apply_heuristics {
heuristics::apply_heuristics(&mut detections, page_width as f32, page_height as f32);
}
let postprocess_ms = postprocess_start.elapsed().as_secs_f64() * 1000.0;
tracing::info!(
preprocess_ms,
onnx_ms,
model_total_ms,
postprocess_ms,
final_detections = detections.len(),
"Layout engine detect_timed() breakdown"
);
let timings = DetectTimings {
preprocess_ms,
onnx_ms,
model_total_ms,
postprocess_ms,
};
Ok((DetectionResult::new(page_width, page_height, detections), timings))
}
pub fn detect_batch(&mut self, images: &[&RgbImage]) -> Result<Vec<(DetectionResult, DetectTimings)>, LayoutError> {
if images.is_empty() {
return Ok(Vec::new());
}
let model_start = Instant::now();
let per_image_detections = self.model.detect_batch(images, self.config.confidence_threshold)?;
let model_total_ms = model_start.elapsed().as_secs_f64() * 1000.0;
let (preprocess_ms, onnx_ms) = crate::layout::inference_timings::take();
let postprocess_start = Instant::now();
let mut results = Vec::with_capacity(images.len());
for (img, mut detections) in images.iter().zip(per_image_detections) {
let page_width = img.width();
let page_height = img.height();
if self.config.apply_heuristics {
heuristics::apply_heuristics(&mut detections, page_width as f32, page_height as f32);
}
results.push((
DetectionResult::new(page_width, page_height, detections),
DetectTimings {
preprocess_ms,
onnx_ms,
model_total_ms,
postprocess_ms: 0.0, },
));
}
let postprocess_ms = postprocess_start.elapsed().as_secs_f64() * 1000.0;
let postprocess_ms_per = postprocess_ms / images.len() as f64;
for (_, timings) in &mut results {
timings.postprocess_ms = postprocess_ms_per;
}
tracing::info!(
preprocess_ms,
onnx_ms,
model_total_ms,
postprocess_ms,
batch_size = images.len(),
total_detections = results.iter().map(|(r, _)| r.detections.len()).sum::<usize>(),
"Layout engine detect_batch() breakdown"
);
Ok(results)
}
pub fn model_name(&self) -> &str {
self.model.name()
}
pub fn config(&self) -> &LayoutEngineConfig {
&self.config
}
}