#![allow(
dead_code,
clippy::unused_self,
clippy::unnecessary_wraps,
clippy::ptr_arg
)]
use std::path::Path;
use image::GenericImageView;
use ndarray::Array4;
#[cfg(feature = "docling-ffi")]
use ort::{
session::{Session, builder::SessionBuilder},
value::Tensor,
};
use crate::error::{Result, TransmutationError};
use crate::ml::{DocumentModel, preprocessing};
/// Layout class assigned to a detected region.
///
/// `LayoutModel` maps output channel `i` of the model to the `i`-th variant
/// listed here (channel numbers are given on each variant).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutLabel {
    /// `Text` region — model output channel 0.
    Text,
    /// `Title` region — model output channel 1.
    Title,
    /// `SectionHeader` region — model output channel 2.
    SectionHeader,
    /// `ListItem` region — model output channel 3.
    ListItem,
    /// `Caption` region — model output channel 4.
    Caption,
    /// `Footnote` region — model output channel 5.
    Footnote,
    /// `PageHeader` region — model output channel 6.
    PageHeader,
    /// `PageFooter` region — model output channel 7.
    PageFooter,
    /// `Table` region — model output channel 8.
    Table,
    /// `Figure` region — model output channel 9.
    Figure,
    /// `Formula` region — model output channel 10.
    Formula,
    /// `Code` region — model output channel 11.
    Code,
}
/// A single region found on a page by the layout model.
#[derive(Debug, Clone)]
pub struct DetectedRegion {
    /// Layout class of the region.
    pub label: LayoutLabel,
    /// Bounding box as `(x0, y0, x1, y1)`.
    ///
    /// When produced by `LayoutModel`, these are pixel indices in the model's
    /// output mask with an inclusive max corner; they are not rescaled to the
    /// page dimensions reported in `LayoutPrediction`.
    pub bbox: (f32, f32, f32, f32),
    /// Detection score. `LayoutModel` uses the mean mask value over `bbox`.
    pub confidence: f32,
}
/// Layout-analysis result for one page image.
#[derive(Debug, Clone)]
pub struct LayoutPrediction {
    /// Detected regions. As returned by `LayoutModel`, these are the survivors of
    /// non-maximum suppression, ordered by descending confidence.
    pub regions: Vec<DetectedRegion>,
    /// Width of the input page image, in pixels.
    pub page_width: u32,
    /// Height of the input page image, in pixels.
    pub page_height: u32,
}
/// Layout detector that runs an ONNX model producing one mask per layout class
/// and converts those masks into [`DetectedRegion`]s.
///
/// Without the `docling-ffi` feature there is no session field, and
/// `LayoutModel::new` always returns an error.
#[derive(Debug)]
pub struct LayoutModel {
    /// ONNX Runtime session for the loaded model.
    #[cfg(feature = "docling-ffi")]
    session: Session,
    /// Path the model was loaded from. Within this file it is only stored
    /// (and surfaced through the derived `Debug` impl), never read.
    model_path: std::path::PathBuf,
}
impl LayoutModel {
pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
let model_path = model_path.as_ref().to_path_buf();
#[cfg(feature = "docling-ffi")]
{
let session = SessionBuilder::new()?
.with_intra_threads(4)?
.commit_from_file(&model_path)
.map_err(|e| TransmutationError::EngineError {
engine: "layout-model".to_string(),
message: format!("Failed to load ONNX model: {e}"),
source: None,
})?;
Ok(Self {
session,
model_path,
})
}
#[cfg(not(feature = "docling-ffi"))]
{
Err(TransmutationError::EngineError(
"layout-model".to_string(),
"docling-ffi feature not enabled".to_string(),
))
}
}
#[cfg(feature = "docling-ffi")]
fn run_inference(&mut self, input: &Array4<f32>) -> Result<Vec<DetectedRegion>> {
let shape = input.shape().to_vec();
let data = input.iter().copied().collect::<Vec<f32>>();
let input_tensor = Tensor::from_array((shape, data))?;
let (output_data, output_shape) = {
let outputs = self.session.run(ort::inputs![input_tensor])?;
let output_value = &outputs[0];
let (shape, data) = output_value.try_extract_tensor::<f32>()?;
(data.to_vec(), shape.to_vec())
};
self.post_process_output_from_data(&output_shape, &output_data)
}
#[cfg(feature = "docling-ffi")]
fn post_process_output_from_data(
&self,
shape: &[i64],
data: &[f32],
) -> Result<Vec<DetectedRegion>> {
if shape.len() != 4 {
return Err(crate::TransmutationError::EngineError {
engine: "layout-model".to_string(),
message: format!("Expected 4D output tensor, got {}D", shape.len()),
source: None,
});
}
let num_classes = shape[1] as usize;
let height = shape[2] as usize;
let width = shape[3] as usize;
use ndarray::Array4;
let masks_array = Array4::from_shape_vec((1, num_classes, height, width), data.to_vec())
.map_err(|e| crate::TransmutationError::EngineError {
engine: "layout-model".to_string(),
message: format!("Failed to reshape tensor: {e}"),
source: None,
})?;
let mut all_regions = Vec::new();
for class_id in 0..num_classes {
let class_mask = masks_array.slice(ndarray::s![0, class_id, .., ..]);
let regions = self.mask_to_regions(&class_mask, class_id, width, height)?;
all_regions.extend(regions);
}
let filtered_regions = self.apply_nms(all_regions, 0.5)?;
Ok(filtered_regions)
}
#[cfg(feature = "docling-ffi")]
fn mask_to_regions(
&self,
mask: &ndarray::ArrayView2<f32>,
class_id: usize,
width: usize,
height: usize,
) -> Result<Vec<DetectedRegion>> {
let threshold = 0.5; let mut regions = Vec::new();
let mut visited = vec![vec![false; width]; height];
for y in 0..height {
for x in 0..width {
if mask[[y, x]] > threshold && !visited[y][x] {
let bbox =
self.flood_fill_bbox(mask, &mut visited, x, y, width, height, threshold);
if let Some((x0, y0, x1, y1)) = bbox {
if let Some(label) = self.class_id_to_label(class_id) {
let confidence = self.calculate_region_confidence(mask, x0, y0, x1, y1);
regions.push(DetectedRegion {
label,
bbox: (x0 as f32, y0 as f32, x1 as f32, y1 as f32),
confidence,
});
}
}
}
}
}
Ok(regions)
}
#[cfg(feature = "docling-ffi")]
fn flood_fill_bbox(
&self,
mask: &ndarray::ArrayView2<f32>,
visited: &mut Vec<Vec<bool>>,
start_x: usize,
start_y: usize,
width: usize,
height: usize,
threshold: f32,
) -> Option<(usize, usize, usize, usize)> {
let mut stack = vec![(start_x, start_y)];
let mut min_x = start_x;
let mut min_y = start_y;
let mut max_x = start_x;
let mut max_y = start_y;
while let Some((x, y)) = stack.pop() {
if x >= width || y >= height || visited[y][x] || mask[[y, x]] <= threshold {
continue;
}
visited[y][x] = true;
min_x = min_x.min(x);
min_y = min_y.min(y);
max_x = max_x.max(x);
max_y = max_y.max(y);
if x > 0 {
stack.push((x - 1, y));
}
if x + 1 < width {
stack.push((x + 1, y));
}
if y > 0 {
stack.push((x, y - 1));
}
if y + 1 < height {
stack.push((x, y + 1));
}
}
if (max_x - min_x) < 5 || (max_y - min_y) < 5 {
return None;
}
Some((min_x, min_y, max_x, max_y))
}
#[cfg(feature = "docling-ffi")]
fn calculate_region_confidence(
&self,
mask: &ndarray::ArrayView2<f32>,
x0: usize,
y0: usize,
x1: usize,
y1: usize,
) -> f32 {
let mut sum = 0.0;
let mut count = 0;
for y in y0..=y1 {
for x in x0..=x1 {
if y < mask.shape()[0] && x < mask.shape()[1] {
sum += mask[[y, x]];
count += 1;
}
}
}
if count > 0 { sum / count as f32 } else { 0.0 }
}
#[cfg(feature = "docling-ffi")]
fn apply_nms(
&self,
mut regions: Vec<DetectedRegion>,
iou_threshold: f32,
) -> Result<Vec<DetectedRegion>> {
regions.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
let mut keep = Vec::new();
let mut suppressed = vec![false; regions.len()];
for i in 0..regions.len() {
if suppressed[i] {
continue;
}
keep.push(regions[i].clone());
for j in (i + 1)..regions.len() {
if suppressed[j] {
continue;
}
let iou = self.calculate_iou(®ions[i].bbox, ®ions[j].bbox);
if iou > iou_threshold {
suppressed[j] = true;
}
}
}
Ok(keep)
}
#[cfg(feature = "docling-ffi")]
fn calculate_iou(&self, bbox1: &(f32, f32, f32, f32), bbox2: &(f32, f32, f32, f32)) -> f32 {
let (x1_min, y1_min, x1_max, y1_max) = bbox1;
let (x2_min, y2_min, x2_max, y2_max) = bbox2;
let inter_x_min = x1_min.max(*x2_min);
let inter_y_min = y1_min.max(*y2_min);
let inter_x_max = x1_max.min(*x2_max);
let inter_y_max = y1_max.min(*y2_max);
if inter_x_max <= inter_x_min || inter_y_max <= inter_y_min {
return 0.0;
}
let inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min);
let area1 = (x1_max - x1_min) * (y1_max - y1_min);
let area2 = (x2_max - x2_min) * (y2_max - y2_min);
let union_area = area1 + area2 - inter_area;
if union_area > 0.0 {
inter_area / union_area
} else {
0.0
}
}
#[cfg(feature = "docling-ffi")]
fn class_id_to_label(&self, class_id: usize) -> Option<LayoutLabel> {
match class_id {
0 => Some(LayoutLabel::Text),
1 => Some(LayoutLabel::Title),
2 => Some(LayoutLabel::SectionHeader),
3 => Some(LayoutLabel::ListItem),
4 => Some(LayoutLabel::Caption),
5 => Some(LayoutLabel::Footnote),
6 => Some(LayoutLabel::PageHeader),
7 => Some(LayoutLabel::PageFooter),
8 => Some(LayoutLabel::Table),
9 => Some(LayoutLabel::Figure),
10 => Some(LayoutLabel::Formula),
11 => Some(LayoutLabel::Code),
_ => None, }
}
}
#[cfg(feature = "docling-ffi")]
impl DocumentModel for LayoutModel {
    type Input = image::DynamicImage;
    type Output = LayoutPrediction;

    /// Runs layout detection on one page image.
    ///
    /// The image is turned into a tensor by `preprocessing::preprocess_for_layout`,
    /// run through the model, and the resulting regions are returned together with
    /// the image's own width and height.
    ///
    /// NOTE(review): region bboxes are in the model's output-mask pixel space,
    /// while `page_width`/`page_height` are the source image size; confirm that
    /// consumers rescale between the two.
    fn predict(&mut self, input: &Self::Input) -> Result<Self::Output> {
        let (page_width, page_height) = input.dimensions();
        let input_tensor = preprocessing::preprocess_for_layout(input)?;
        let detected = self.run_inference(&input_tensor)?;
        Ok(LayoutPrediction {
            regions: detected,
            page_width,
            page_height,
        })
    }

    /// Fixed identifier for this model.
    fn name(&self) -> &str {
        "LayoutModel"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for loading the real model. Ignored by default because it
    /// needs `models/layout_model.onnx` on disk, and `new` can only succeed
    /// when the `docling-ffi` feature is enabled.
    #[test]
    #[ignore]
    fn test_load_model() {
        let _result = LayoutModel::new("models/layout_model.onnx");
    }

    /// Without `docling-ffi` there is no inference backend, so construction
    /// must fail rather than return an unusable model.
    #[cfg(not(feature = "docling-ffi"))]
    #[test]
    fn new_fails_without_docling_ffi() {
        assert!(LayoutModel::new("models/layout_model.onnx").is_err());
    }
}