pub mod engine;
pub mod error;
pub(crate) mod inference_timings;
mod model_manager;
pub mod models;
pub mod postprocessing;
pub mod preprocessing;
pub mod session;
pub mod types;
pub use engine::{CustomModelVariant, DetectTimings, LayoutEngine, LayoutEngineConfig, LayoutPreset, ModelBackend};
pub use error::LayoutError;
pub use model_manager::LayoutModelManager;
pub use models::LayoutModel;
pub use models::rtdetr::RtDetrModel;
pub use models::yolo::{YoloModel, YoloVariant};
pub use types::{BBox, DetectionResult, LayoutClass, LayoutDetection};
use std::sync::OnceLock;
use crate::core::config::layout::LayoutDetectionConfig;
use crate::model_cache::ModelCache;
/// Cache of [`LayoutEngine`]s: `take_or_create_engine` takes from it and
/// `return_engine` puts engines back into it.
static CACHED_ENGINE: ModelCache<LayoutEngine> = ModelCache::new();
/// Cache of TATR table-structure models used by `take_or_create_tatr` / `return_tatr`.
static CACHED_TATR: ModelCache<models::tatr::TatrModel> = ModelCache::new();
/// Latches the first observed outcome of loading TATR: `Some(true)` if a load succeeded
/// first, `Some(false)` if a load failed first, unset if nothing has been attempted yet.
/// Once it holds `false`, `take_or_create_tatr` returns `None` immediately and never retries.
static TATR_TRIED: OnceLock<bool> = OnceLock::new();
/// Translates the user-facing [`LayoutDetectionConfig`] into a [`LayoutEngineConfig`].
///
/// The `preset` string is parsed as a [`LayoutPreset`]. If parsing fails, a warning is
/// logged and [`LayoutPreset::Accurate`] is used instead. The preset's defaults are then
/// overridden with the caller's `confidence_threshold` and `apply_heuristics` values.
pub fn config_from_extraction(layout_config: &LayoutDetectionConfig) -> LayoutEngineConfig {
    let preset = match layout_config.preset.parse::<LayoutPreset>() {
        Ok(parsed) => parsed,
        Err(_) => {
            tracing::warn!(
                preset = %layout_config.preset,
                "unrecognized layout preset, falling back to 'accurate'"
            );
            LayoutPreset::Accurate
        }
    };

    // Start from the preset, then apply the two explicitly configured knobs on top.
    let mut config = LayoutEngineConfig::from_preset(preset);
    config.confidence_threshold = layout_config.confidence_threshold;
    config.apply_heuristics = layout_config.apply_heuristics;
    config
}
/// Builds a brand-new [`LayoutEngine`] for `layout_config`, without consulting any cache.
///
/// Calls `crate::ort_discovery::ensure_ort_available` first. The engine configuration is
/// then derived with [`config_from_extraction`].
///
/// # Errors
///
/// Propagates any error returned by [`LayoutEngine::from_config`].
pub fn create_engine(layout_config: &LayoutDetectionConfig) -> Result<LayoutEngine, LayoutError> {
    crate::ort_discovery::ensure_ort_available();
    LayoutEngine::from_config(config_from_extraction(layout_config))
}
/// Obtains a [`LayoutEngine`] from the shared engine cache. If the cache cannot supply one,
/// a new engine is built with [`create_engine`].
///
/// Pair every call with [`return_engine`] so the engine can be reused.
///
/// NOTE(review): `layout_config` is only used by the build closure. An engine already in
/// the cache is presumably returned as-is, even if it was built from a different config.
/// Confirm that callers always pass the same config.
///
/// # Errors
///
/// Propagates any error from [`create_engine`] when a new engine has to be built.
pub fn take_or_create_engine(layout_config: &LayoutDetectionConfig) -> Result<LayoutEngine, LayoutError> {
    let build_engine = || create_engine(layout_config);
    CACHED_ENGINE.take_or_create(build_engine)
}
/// Hands `engine` back to the shared engine cache via `ModelCache::put`.
///
/// Intended for engines previously obtained from [`take_or_create_engine`].
pub fn return_engine(engine: LayoutEngine) {
    let cache = &CACHED_ENGINE;
    cache.put(engine);
}
/// Takes a TATR table-structure model from the shared cache, loading it from disk if needed.
///
/// The model file is resolved via `LayoutModelManager::ensure_tatr_model`, after calling
/// `crate::ort_discovery::ensure_ort_available`.
///
/// The first observed load outcome is latched. If the first outcome is a failure, a
/// warning is logged once. After that, every call returns `None` immediately without
/// retrying. If a load has already succeeded, later failures are still logged rather than
/// silently dropped.
///
/// Give the model back with [`return_tatr`] when finished.
pub fn take_or_create_tatr() -> Option<models::tatr::TatrModel> {
    if let Some(&false) = TATR_TRIED.get() {
        return None;
    }
    let result = CACHED_TATR.take_or_create(|| {
        crate::ort_discovery::ensure_ort_available();
        let manager = LayoutModelManager::new(None);
        let model_path = manager.ensure_tatr_model()?;
        models::tatr::TatrModel::from_file(&model_path.to_string_lossy())
    });
    match result {
        Ok(model) => {
            TATR_TRIED.get_or_init(|| true);
            Some(model)
        }
        Err(e) => {
            // The closure (and its warning) only runs if the latch is still unset.
            let loaded_before = *TATR_TRIED.get_or_init(|| {
                tracing::warn!("TATR table structure model unavailable, table structure recognition disabled: {e}");
                false
            });
            // The latch is already `true` from an earlier success, so the closure above did
            // not run. Report this failure instead of discarding `e` silently.
            if loaded_before {
                tracing::warn!("TATR table structure model failed to load after an earlier successful load: {e}");
            }
            None
        }
    }
}
/// Hands a TATR model back to the TATR cache via `ModelCache::put`.
///
/// Intended for models previously obtained from [`take_or_create_tatr`].
pub fn return_tatr(model: models::tatr::TatrModel) {
    let cache = &CACHED_TATR;
    cache.put(model);
}
/// Cache for the `"slanet_wired"` SLANet variant (see `take_or_create_slanet` / `return_slanet`).
static CACHED_SLANET_WIRED: ModelCache<models::slanet::SlanetModel> = ModelCache::new();
/// Cache for the `"slanet_wireless"` SLANet variant.
static CACHED_SLANET_WIRELESS: ModelCache<models::slanet::SlanetModel> = ModelCache::new();
/// Cache for the `"slanet_plus"` SLANet variant.
static CACHED_SLANET_PLUS: ModelCache<models::slanet::SlanetModel> = ModelCache::new();
/// Cache of table classifiers used by `take_or_create_table_classifier` /
/// `return_table_classifier`.
static CACHED_TABLE_CLASSIFIER: ModelCache<models::table_classifier::TableClassifier> = ModelCache::new();
// Per-model load latches: `Some(true)` if a load succeeded first, `Some(false)` if a load
// failed first. Once `false`, the matching `take_or_create_*` function returns `None`
// without retrying.
static SLANET_WIRED_TRIED: OnceLock<bool> = OnceLock::new();
static SLANET_WIRELESS_TRIED: OnceLock<bool> = OnceLock::new();
static SLANET_PLUS_TRIED: OnceLock<bool> = OnceLock::new();
static TABLE_CLASSIFIER_TRIED: OnceLock<bool> = OnceLock::new();
/// Table-structure recognition backend chosen by configuration.
///
/// Constructed from a config string with [`TableModelBackend::from_config`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TableModelBackend {
    /// TATR model. Config value `"tatr"`; also used when no value is configured.
    Tatr,
    /// SLANet wired variant. Config value `"slanet_wired"`.
    SlanetWired,
    /// SLANet wireless variant. Config value `"slanet_wireless"`.
    SlanetWireless,
    /// SLANet plus variant. Config value `"slanet_plus"`.
    SlanetPlus,
    /// Config value `"slanet_auto"`. NOTE(review): presumably picks a SLANet variant per
    /// table (e.g. via the table classifier); confirm at the call site.
    SlanetAuto,
}

impl TableModelBackend {
    /// Maps an optional config string to a backend.
    ///
    /// `None` and `"tatr"` select [`TableModelBackend::Tatr`]. Matching is exact and
    /// case-sensitive. Any other unrecognized string logs a warning and falls back to
    /// [`TableModelBackend::Tatr`].
    pub fn from_config(s: Option<&str>) -> Self {
        let Some(name) = s else {
            return Self::Tatr;
        };
        match name {
            "tatr" => Self::Tatr,
            "slanet_wired" => Self::SlanetWired,
            "slanet_wireless" => Self::SlanetWireless,
            "slanet_plus" => Self::SlanetPlus,
            "slanet_auto" => Self::SlanetAuto,
            unknown => {
                tracing::warn!(table_model = unknown, "Unknown table model, falling back to TATR");
                Self::Tatr
            }
        }
    }
}
/// Takes a SLANet table-structure model for `variant` from its cache, loading it if needed.
///
/// `variant` must be `"slanet_wired"`, `"slanet_wireless"` or `"slanet_plus"`. Any other
/// string, including `"slanet_auto"`, returns `None` without touching a cache.
///
/// Each variant has its own cache and its own load latch. If a variant's first observed
/// load outcome is a failure, a warning is logged once and later calls for that variant
/// return `None` without retrying. If a load has already succeeded, later failures are
/// still logged rather than silently dropped.
///
/// Give the model back with [`return_slanet`], using the same `variant` string.
pub fn take_or_create_slanet(variant: &str) -> Option<models::slanet::SlanetModel> {
    let (cache, tried) = match variant {
        "slanet_wired" => (&CACHED_SLANET_WIRED, &SLANET_WIRED_TRIED),
        "slanet_wireless" => (&CACHED_SLANET_WIRELESS, &SLANET_WIRELESS_TRIED),
        "slanet_plus" => (&CACHED_SLANET_PLUS, &SLANET_PLUS_TRIED),
        _ => return None,
    };
    if let Some(&false) = tried.get() {
        return None;
    }
    let result = cache.take_or_create(|| {
        crate::ort_discovery::ensure_ort_available();
        let manager = LayoutModelManager::new(None);
        let model_path = manager.ensure_slanet_model(variant)?;
        models::slanet::SlanetModel::from_file(&model_path.to_string_lossy())
    });
    match result {
        Ok(model) => {
            tried.get_or_init(|| true);
            Some(model)
        }
        Err(e) => {
            // The closure (and its warning) only runs if the latch is still unset.
            let loaded_before = *tried.get_or_init(|| {
                tracing::warn!(variant, "SLANeXT model unavailable: {e}");
                false
            });
            // The latch is already `true` from an earlier success, so the closure above did
            // not run. Report this failure instead of discarding `e` silently.
            if loaded_before {
                tracing::warn!(variant, "SLANeXT model failed to load after an earlier successful load: {e}");
            }
            None
        }
    }
}
/// Hands a SLANet model back to the cache for `variant` via `ModelCache::put`.
///
/// `variant` should be the string the model was taken with in [`take_or_create_slanet`].
/// For any other string (including `"slanet_auto"`), no cache is updated and `model` is
/// dropped.
pub fn return_slanet(variant: &str, model: models::slanet::SlanetModel) {
    let cache = match variant {
        "slanet_wired" => &CACHED_SLANET_WIRED,
        "slanet_wireless" => &CACHED_SLANET_WIRELESS,
        "slanet_plus" => &CACHED_SLANET_PLUS,
        // No cache for this variant: `model` is dropped on return.
        _ => return,
    };
    cache.put(model);
}
/// Takes a table classifier from the shared cache, loading it from disk if needed.
///
/// The model file is resolved via `LayoutModelManager::ensure_table_classifier`, after
/// calling `crate::ort_discovery::ensure_ort_available`.
///
/// If the first observed load outcome is a failure, a warning is logged once and every
/// later call returns `None` without retrying. If a load has already succeeded, later
/// failures are still logged rather than silently dropped.
///
/// Give the classifier back with [`return_table_classifier`] when finished.
pub fn take_or_create_table_classifier() -> Option<models::table_classifier::TableClassifier> {
    if let Some(&false) = TABLE_CLASSIFIER_TRIED.get() {
        return None;
    }
    let result = CACHED_TABLE_CLASSIFIER.take_or_create(|| {
        crate::ort_discovery::ensure_ort_available();
        let manager = LayoutModelManager::new(None);
        let model_path = manager.ensure_table_classifier()?;
        models::table_classifier::TableClassifier::from_file(&model_path.to_string_lossy())
    });
    match result {
        Ok(model) => {
            TABLE_CLASSIFIER_TRIED.get_or_init(|| true);
            Some(model)
        }
        Err(e) => {
            // The closure (and its warning) only runs if the latch is still unset.
            let loaded_before = *TABLE_CLASSIFIER_TRIED.get_or_init(|| {
                tracing::warn!("Table classifier unavailable: {e}");
                false
            });
            // The latch is already `true` from an earlier success, so the closure above did
            // not run. Report this failure instead of discarding `e` silently.
            if loaded_before {
                tracing::warn!("Table classifier failed to load after an earlier successful load: {e}");
            }
            None
        }
    }
}
/// Hands a table classifier back to its cache via `ModelCache::put`.
///
/// Intended for classifiers previously obtained from [`take_or_create_table_classifier`].
pub fn return_table_classifier(model: models::table_classifier::TableClassifier) {
    let cache = &CACHED_TABLE_CLASSIFIER;
    cache.put(model);
}