pub mod engine;
pub mod error;
pub(crate) mod inference_timings;
mod model_manager;
pub mod models;
pub mod postprocessing;
pub mod preprocessing;
pub mod session;
pub mod types;
pub use engine::{CustomModelVariant, DetectTimings, LayoutEngine, LayoutEngineConfig, ModelBackend};
pub use error::LayoutError;
pub use model_manager::LayoutModelManager;
pub use models::LayoutModel;
pub use models::rtdetr::RtDetrModel;
pub use models::yolo::{YoloModel, YoloVariant};
pub use types::{BBox, DetectionResult, LayoutClass, LayoutDetection};
use std::sync::OnceLock;
use crate::core::config::layout::LayoutDetectionConfig;
use crate::model_cache::ModelCache;
/// Cache of [`LayoutEngine`]s: filled by [`return_engine`], drained by [`take_or_create_engine`].
static CACHED_ENGINE: ModelCache<LayoutEngine> = ModelCache::new();
/// Cache of TATR models: filled by [`return_tatr`], drained by [`take_or_create_tatr`].
static CACHED_TATR: ModelCache<models::tatr::TatrModel> = ModelCache::new();
/// First recorded outcome of a TATR load attempt.
///
/// Unset means no attempt has finished yet, `Some(true)` means the first recorded attempt
/// succeeded, and `Some(false)` means it failed. In the `Some(false)` case,
/// [`take_or_create_tatr`] returns `None` immediately and never retries. A `OnceLock` in a
/// `static` cannot be cleared, so this holds for the rest of the process.
static TATR_TRIED: OnceLock<bool> = OnceLock::new();
/// Translates the extraction-level [`LayoutDetectionConfig`] into a [`LayoutEngineConfig`].
///
/// The backend is fixed to [`ModelBackend::RtDetr`] and `cache_dir` is left as `None`.
/// The confidence threshold, the heuristics flag and a clone of the acceleration
/// settings are taken from `layout_config`.
pub fn config_from_extraction(layout_config: &LayoutDetectionConfig) -> LayoutEngineConfig {
    let acceleration = layout_config.acceleration.clone();
    LayoutEngineConfig {
        backend: ModelBackend::RtDetr,
        cache_dir: None,
        confidence_threshold: layout_config.confidence_threshold,
        apply_heuristics: layout_config.apply_heuristics,
        acceleration,
    }
}
/// Builds a brand-new [`LayoutEngine`] for `layout_config` without consulting the engine cache.
///
/// `ort_discovery::ensure_ort_available()` is called first. The engine is then constructed
/// from the result of [`config_from_extraction`].
///
/// # Errors
/// Returns whatever [`LayoutError`] [`LayoutEngine::from_config`] reports.
pub fn create_engine(layout_config: &LayoutDetectionConfig) -> Result<LayoutEngine, LayoutError> {
    crate::ort_discovery::ensure_ort_available();
    LayoutEngine::from_config(config_from_extraction(layout_config))
}
/// Obtains a [`LayoutEngine`] through `CACHED_ENGINE.take_or_create`, using
/// [`create_engine`] as the constructor.
///
/// Presumably this reuses a previously returned engine when one is cached. Hand engines back
/// with [`return_engine`] so they can be reused.
///
/// # Errors
/// Propagates the [`LayoutError`] produced when a new engine has to be built and that fails.
pub fn take_or_create_engine(layout_config: &LayoutDetectionConfig) -> Result<LayoutEngine, LayoutError> {
    let build_engine = || create_engine(layout_config);
    CACHED_ENGINE.take_or_create(build_engine)
}
/// Puts `engine` back into the engine cache so a later [`take_or_create_engine`] can reuse it.
pub fn return_engine(engine: LayoutEngine) {
    let cache = &CACHED_ENGINE;
    cache.put(engine);
}
/// Takes a cached TATR table-structure model, or builds a new one.
///
/// A new model is built by calling `ensure_ort_available()`, resolving the model path via
/// `LayoutModelManager::ensure_tatr_model`, and loading it with `TatrModel::from_file`.
/// `accel` is cloned and forwarded to `from_file`.
///
/// The first recorded outcome is stored in `TATR_TRIED`. If that first outcome is a failure,
/// every later call returns `None` immediately without retrying. Return the model with
/// [`return_tatr`] when finished.
///
/// Returns `None` when no model could be obtained. The cause is always logged.
pub fn take_or_create_tatr(
    accel: Option<&crate::core::config::acceleration::AccelerationConfig>,
) -> Option<models::tatr::TatrModel> {
    if TATR_TRIED.get() == Some(&false) {
        return None;
    }
    let accel_cloned = accel.cloned();
    let result = CACHED_TATR.take_or_create(|| {
        crate::ort_discovery::ensure_ort_available();
        let manager = LayoutModelManager::new(None);
        let model_path = manager.ensure_tatr_model()?;
        models::tatr::TatrModel::from_file(&model_path.to_string_lossy(), accel_cloned.as_ref())
    });
    match result {
        Ok(model) => {
            // Only the first outcome is recorded; an `Err` here just means it already was.
            let _ = TATR_TRIED.set(true);
            Some(model)
        }
        Err(e) => {
            if TATR_TRIED.set(false).is_ok() {
                // First recorded outcome is a failure: TATR stays disabled from now on.
                tracing::warn!("TATR table structure model unavailable, table structure recognition disabled: {e}");
            } else {
                // The outcome was already recorded (e.g. an earlier success), so `set` was a
                // no-op. Still surface the error instead of dropping it silently.
                tracing::warn!("Failed to create TATR table structure model instance: {e}");
            }
            None
        }
    }
}
/// Puts a TATR model back into its cache for reuse by [`take_or_create_tatr`].
pub fn return_tatr(model: models::tatr::TatrModel) {
    let cache = &CACHED_TATR;
    cache.put(model);
}
/// Cache of SLANet models for the `"slanet_wired"` variant (see [`take_or_create_slanet`]).
static CACHED_SLANET_WIRED: ModelCache<models::slanet::SlanetModel> = ModelCache::new();
/// Cache of SLANet models for the `"slanet_wireless"` variant (see [`take_or_create_slanet`]).
static CACHED_SLANET_WIRELESS: ModelCache<models::slanet::SlanetModel> = ModelCache::new();
/// Cache of SLANet models for the `"slanet_plus"` variant (see [`take_or_create_slanet`]).
static CACHED_SLANET_PLUS: ModelCache<models::slanet::SlanetModel> = ModelCache::new();
/// Cache of table classifiers (see [`take_or_create_table_classifier`]).
static CACHED_TABLE_CLASSIFIER: ModelCache<models::table_classifier::TableClassifier> = ModelCache::new();
// First recorded load outcome per model. Unset = no attempt finished yet, `Some(true)` =
// first recorded attempt succeeded, `Some(false)` = it failed. `Some(false)` makes the
// matching `take_or_create_*` function return `None` without retrying; a static `OnceLock`
// cannot be cleared, so this persists for the process.
static SLANET_WIRED_TRIED: OnceLock<bool> = OnceLock::new();
static SLANET_WIRELESS_TRIED: OnceLock<bool> = OnceLock::new();
static SLANET_PLUS_TRIED: OnceLock<bool> = OnceLock::new();
static TABLE_CLASSIFIER_TRIED: OnceLock<bool> = OnceLock::new();
/// Takes a cached SLANet model for `variant`, or builds a new one.
///
/// `variant` must be one of `"slanet_wired"`, `"slanet_wireless"` or `"slanet_plus"`.
/// Any other string returns `None` without logging.
///
/// A new model is built by calling `ensure_ort_available()`, resolving the model path via
/// `LayoutModelManager::ensure_slanet_model(variant)`, and loading it with
/// `SlanetModel::from_file`. `accel` is cloned and forwarded to `from_file`.
///
/// Each variant records its first outcome in its own `*_TRIED` flag. If that first outcome is
/// a failure, later calls for the same variant return `None` immediately without retrying.
/// Return the model with [`return_slanet`] using the same `variant`.
///
/// Returns `None` when no model could be obtained. Load failures are always logged.
pub fn take_or_create_slanet(
    variant: &str,
    accel: Option<&crate::core::config::acceleration::AccelerationConfig>,
) -> Option<models::slanet::SlanetModel> {
    let (cache, tried) = match variant {
        "slanet_wired" => (&CACHED_SLANET_WIRED, &SLANET_WIRED_TRIED),
        "slanet_wireless" => (&CACHED_SLANET_WIRELESS, &SLANET_WIRELESS_TRIED),
        "slanet_plus" => (&CACHED_SLANET_PLUS, &SLANET_PLUS_TRIED),
        _ => return None,
    };
    if tried.get() == Some(&false) {
        return None;
    }
    let accel_cloned = accel.cloned();
    let result = cache.take_or_create(|| {
        crate::ort_discovery::ensure_ort_available();
        let manager = LayoutModelManager::new(None);
        let model_path = manager.ensure_slanet_model(variant)?;
        models::slanet::SlanetModel::from_file(&model_path.to_string_lossy(), accel_cloned.as_ref())
    });
    match result {
        Ok(model) => {
            // Only the first outcome is recorded; an `Err` here just means it already was.
            let _ = tried.set(true);
            Some(model)
        }
        Err(e) => {
            if tried.set(false).is_ok() {
                // First recorded outcome is a failure: this variant stays disabled from now on.
                tracing::warn!(variant, "SLANeXT model unavailable: {e}");
            } else {
                // The outcome was already recorded (e.g. an earlier success), so `set` was a
                // no-op. Still surface the error instead of dropping it silently.
                tracing::warn!(variant, "Failed to create SLANeXT model instance: {e}");
            }
            None
        }
    }
}
/// Puts a SLANet model back into the cache for `variant`, for reuse by [`take_or_create_slanet`].
///
/// An unrecognized `variant` has no cache, so the model is simply dropped.
pub fn return_slanet(variant: &str, model: models::slanet::SlanetModel) {
    let cache = match variant {
        "slanet_wired" => &CACHED_SLANET_WIRED,
        "slanet_wireless" => &CACHED_SLANET_WIRELESS,
        "slanet_plus" => &CACHED_SLANET_PLUS,
        _ => return,
    };
    cache.put(model);
}
/// Takes a cached table classifier, or builds a new one.
///
/// A new classifier is built by calling `ensure_ort_available()`, resolving the model path via
/// `LayoutModelManager::ensure_table_classifier`, and loading it with
/// `TableClassifier::from_file`. `accel` is cloned and forwarded to `from_file`.
///
/// The first recorded outcome is stored in `TABLE_CLASSIFIER_TRIED`. If that first outcome is
/// a failure, every later call returns `None` immediately without retrying. Return the
/// classifier with [`return_table_classifier`] when finished.
///
/// Returns `None` when no classifier could be obtained. The cause is always logged.
pub fn take_or_create_table_classifier(
    accel: Option<&crate::core::config::acceleration::AccelerationConfig>,
) -> Option<models::table_classifier::TableClassifier> {
    if TABLE_CLASSIFIER_TRIED.get() == Some(&false) {
        return None;
    }
    let accel_cloned = accel.cloned();
    let result = CACHED_TABLE_CLASSIFIER.take_or_create(|| {
        crate::ort_discovery::ensure_ort_available();
        let manager = LayoutModelManager::new(None);
        let model_path = manager.ensure_table_classifier()?;
        models::table_classifier::TableClassifier::from_file(&model_path.to_string_lossy(), accel_cloned.as_ref())
    });
    match result {
        Ok(model) => {
            // Only the first outcome is recorded; an `Err` here just means it already was.
            let _ = TABLE_CLASSIFIER_TRIED.set(true);
            Some(model)
        }
        Err(e) => {
            if TABLE_CLASSIFIER_TRIED.set(false).is_ok() {
                // First recorded outcome is a failure: the classifier stays disabled from now on.
                tracing::warn!("Table classifier unavailable: {e}");
            } else {
                // The outcome was already recorded (e.g. an earlier success), so `set` was a
                // no-op. Still surface the error instead of dropping it silently.
                tracing::warn!("Failed to create table classifier instance: {e}");
            }
            None
        }
    }
}
/// Puts a table classifier back into its cache for reuse by [`take_or_create_table_classifier`].
pub fn return_table_classifier(model: models::table_classifier::TableClassifier) {
    let cache = &CACHED_TABLE_CLASSIFIER;
    cache.put(model);
}