use image::imageops::FilterType;
use image::RgbImage;
use ort::session::Session;
use ort::value::Tensor;
/// Class names of the layout model, indexed by class id.
///
/// `LayoutModel::predict` looks up the class id of each detection in this
/// table; ids outside the table fall back to `"text"` there.
pub const LABELS: [&str; 17] = [
    "caption",
    "footnote",
    "formula",
    "list_item",
    "page_footer",
    "page_header",
    "picture",
    "section_header",
    "table",
    "text",
    "title",
    "document_index",
    "code",
    "checkbox_selected",
    "checkbox_unselected",
    "form",
    "key_value_region",
];
/// A single detected layout region on a page.
///
/// The four edge fields are produced by `LayoutModel::predict` from a
/// centre/size box scaled by the page width and height passed to it, so they
/// are expressed in whatever unit the caller used for the page size.
#[derive(Clone, Debug)]
pub struct Region {
    /// Class name of the region, taken from [`LABELS`].
    pub label: &'static str,
    /// Sigmoid of the class logit that produced this detection.
    pub score: f32,
    /// Horizontal coordinate of the box's low-x edge (`cx - w / 2`, scaled).
    pub l: f32,
    /// Vertical coordinate of the box's low-y edge (`cy - h / 2`, scaled).
    pub t: f32,
    /// Horizontal coordinate of the box's high-x edge (`cx + w / 2`, scaled).
    pub r: f32,
    /// Vertical coordinate of the box's high-y edge (`cy + h / 2`, scaled).
    pub b: f32,
}
/// Score a candidate must strictly exceed to be returned as a region.
const THRESHOLD: f32 = 0.3;
/// Side length, in pixels, of the square image the page is resized to before
/// it is fed to the model.
const SIDE: u32 = 640;
/// Layout detector that owns the ONNX Runtime session used for inference.
pub struct LayoutModel {
    /// Session created from the model file when the detector is loaded.
    session: Session,
}
impl LayoutModel {
    /// Loads the layout model into an ONNX Runtime session.
    ///
    /// The model path is read from the `DOCLING_LAYOUT_ONNX` environment
    /// variable; if that variable cannot be read, the relative path
    /// `models/layout_heron.onnx` is used instead.
    ///
    /// # Errors
    ///
    /// Returns a `layout:`-prefixed message if the session builder cannot be
    /// created or the model file cannot be loaded.
    pub fn load() -> Result<Self, String> {
        let path = std::env::var("DOCLING_LAYOUT_ONNX")
            .unwrap_or_else(|_| "models/layout_heron.onnx".to_string());
        let mut builder = Session::builder().map_err(|e| format!("layout: builder: {e}"))?;
        let session = builder
            .commit_from_file(&path)
            .map_err(|e| format!("layout: load {path}: {e}"))?;
        Ok(Self { session })
    }

    /// Runs layout detection on a page image.
    ///
    /// The image is resized to `SIDE` x `SIDE` (aspect ratio is not
    /// preserved), converted to a planar `[1, 3, SIDE, SIDE]` float tensor
    /// with values in `[0, 1]`, and fed to the model as `pixel_values`. The
    /// `logits` and `pred_boxes` outputs are then decoded into regions whose
    /// coordinates are scaled by `page_w` and `page_h`.
    ///
    /// Returned regions are ordered by descending score; at most one region
    /// per model query is returned and every score is strictly above
    /// `THRESHOLD`.
    ///
    /// # Errors
    ///
    /// Returns a `layout:`-prefixed message if the input tensor cannot be
    /// built, inference fails, an output cannot be extracted, or the output
    /// shapes are inconsistent with the data they describe.
    pub fn predict(
        &mut self,
        img: &RgbImage,
        page_w: f32,
        page_h: f32,
    ) -> Result<Vec<Region>, String> {
        let data = Self::to_chw(img);
        let input = Tensor::from_array(([1usize, 3, SIDE as usize, SIDE as usize], data))
            .map_err(|e| format!("layout: input tensor: {e}"))?;
        let outputs = self
            .session
            .run(ort::inputs!["pixel_values" => input])
            .map_err(|e| format!("layout: inference: {e}"))?;
        let (lshape, logits) = outputs["logits"]
            .try_extract_tensor::<f32>()
            .map_err(|e| format!("layout: extract logits: {e}"))?;
        let (_, boxes) = outputs["pred_boxes"]
            .try_extract_tensor::<f32>()
            .map_err(|e| format!("layout: extract boxes: {e}"))?;

        // The logits are read as [batch, queries, classes]; refuse anything
        // with fewer dimensions instead of panicking on the index below.
        if lshape.len() < 3 {
            return Err(format!(
                "layout: logits has rank {}, expected at least 3",
                lshape.len()
            ));
        }
        // Dimensions are signed; a negative (dynamic) value must not be
        // reinterpreted as a huge unsigned count.
        let num_queries = usize::try_from(lshape[1])
            .map_err(|_| format!("layout: invalid query count {}", lshape[1]))?;
        let num_classes = usize::try_from(lshape[2])
            .map_err(|_| format!("layout: invalid class count {}", lshape[2]))?;

        Self::decode(logits, boxes, num_queries, num_classes, page_w, page_h)
    }

    /// Resizes `img` to `SIDE` x `SIDE` and returns its pixels as three
    /// consecutive planes (channel 0, 1, 2), each value divided by 255.
    fn to_chw(img: &RgbImage) -> Vec<f32> {
        let resized = image::imageops::resize(img, SIDE, SIDE, FilterType::Triangle);
        let plane = (SIDE * SIDE) as usize;
        let mut data = vec![0f32; 3 * plane];
        {
            let (first, rest) = data.split_at_mut(plane);
            let (second, third) = rest.split_at_mut(plane);
            let planes = first.iter_mut().zip(second.iter_mut()).zip(third.iter_mut());
            for (px, ((c0, c1), c2)) in resized.pixels().zip(planes) {
                *c0 = f32::from(px[0]) / 255.0;
                *c1 = f32::from(px[1]) / 255.0;
                *c2 = f32::from(px[2]) / 255.0;
            }
        }
        data
    }

    /// Turns raw model outputs into regions: keeps the `num_queries`
    /// best-scoring (query, class) pairs whose score exceeds `THRESHOLD` and
    /// converts each pair's centre/size box into page-scaled edges.
    fn decode(
        logits: &[f32],
        boxes: &[f32],
        num_queries: usize,
        num_classes: usize,
        page_w: f32,
        page_h: f32,
    ) -> Result<Vec<Region>, String> {
        let num_scores = num_queries
            .checked_mul(num_classes)
            .filter(|&count| count <= logits.len())
            .ok_or_else(|| {
                format!(
                    "layout: logits holds {} values, expected {num_queries} x {num_classes}",
                    logits.len()
                )
            })?;
        let boxes_fit = num_queries
            .checked_mul(4)
            .is_some_and(|count| count <= boxes.len());
        if !boxes_fit {
            return Err(format!(
                "layout: pred_boxes holds {} values, expected {num_queries} x 4",
                boxes.len()
            ));
        }

        // Threshold before sorting: the survivors are exactly the candidates
        // the full sort would have ranked first, and `>` also rejects NaN
        // scores, which would otherwise sort to the top and be emitted.
        let mut scored: Vec<(f32, usize)> = logits[..num_scores]
            .iter()
            .enumerate()
            .map(|(idx, &logit)| (sigmoid(logit), idx))
            .filter(|&(score, _)| score > THRESHOLD)
            .collect();
        scored.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));
        scored.truncate(num_queries);

        let regions = scored
            .into_iter()
            .map(|(score, idx)| {
                // The flat index encodes (query, class) in row-major order.
                let label_id = idx % num_classes;
                let query = idx / num_classes;
                let coords = &boxes[query * 4..query * 4 + 4];
                let (cx, cy, w, h) = (coords[0], coords[1], coords[2], coords[3]);
                Region {
                    label: LABELS.get(label_id).copied().unwrap_or("text"),
                    score,
                    l: (cx - w / 2.0) * page_w,
                    t: (cy - h / 2.0) * page_h,
                    r: (cx + w / 2.0) * page_w,
                    b: (cy + h / 2.0) * page_h,
                }
            })
            .collect();
        Ok(regions)
    }
}
/// Logistic function `1 / (1 + e^-x)`, mapping a logit to a score in `[0, 1]`.
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}